{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "48HjPXSxsiSO",
    "outputId": "1d36f06e-16a4-42bc-eba7-dc7d249589d7"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7de1f01bc150>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "from datasets import load_dataset\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "# Seed torch's global RNG for reproducibility; the trailing ';' suppresses the echoed Generator repr\n",
    "torch.manual_seed(12046);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "Hn9ypPW0siSP"
   },
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "context_length = 10  # length of the character context window used to build training samples\n",
    "learning_rate = 0.01\n",
    "eval_iters = 10\n",
    "batch_size = 1000\n",
    "# Prefer the GPU when one is available\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "else:\n",
    "    device = 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "P_-nzG89siSQ",
    "outputId": "4329e089-bae9-417b-e67f-7cef29c9d1d4"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def to_arrow_schema(schema):\n",
      "    \"\"\" Convert a schema from Spark to Arrow\n",
      "    \"\"\"\n",
      "    import pyarrow as pa\n",
      "    fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)\n",
      "              for field in schema]\n",
      "    return pa.schema(fields)\n",
      "['def to_arrow_schema(schema):\\n    \"\"\" Convert a schema from Spark to Arrow\\n    \"\"\"\\n    import pyarrow as pa\\n    fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)\\n              for field in schema]\\n    return pa.schema(fields)', 'def from_arrow_type(at):\\n    \"\"\" Convert pyarrow type to Spark data type.\\n    \"\"\"\\n    import pyarrow.types as types\\n    if types.is_boolean(at):\\n        spark_type = BooleanType()\\n    elif types.is_int8(at):\\n        spark_type = ByteType()\\n    elif types.is_int16(at):\\n        spark_type = ShortType()\\n    elif types.is_int32(at):\\n        spark_type = IntegerType()\\n    elif types.is_int64(at):\\n        spark_type = LongType()\\n    elif types.is_float32(at):\\n        spark_type = FloatType()\\n    elif types.is_float64(at):\\n        spark_type = DoubleType()\\n    elif types.is_decimal(at):\\n        spark_type = DecimalType(precision=at.precision, scale=at.scale)\\n    elif types.is_string(at):\\n        spark_type = StringType()\\n    elif types.is_binary(at):\\n        spark_type = BinaryType()\\n    elif types.is_date32(at):\\n        spark_type = DateType()\\n    elif types.is_timestamp(at):\\n        spark_type = TimestampType()\\n    elif types.is_list(at):\\n        if types.is_timestamp(at.value_type):\\n            raise TypeError(\"Unsupported type in conversion from Arrow: \" + str(at))\\n        spark_type = ArrayType(from_arrow_type(at.value_type))\\n    elif types.is_struct(at):\\n        if any(types.is_struct(field.type) for field in at):\\n            raise TypeError(\"Nested StructType not supported in conversion from Arrow: \" + str(at))\\n        return StructType(\\n            [StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)\\n             for field in at])\\n    else:\\n        raise TypeError(\"Unsupported type in conversion from Arrow: \" + str(at))\\n    return spark_type']\n"
     ]
    }
   ],
   "source": [
    "# Load the Python subset of CodeSearchNet and keep only functions from the apache/spark repository\n",
    "# NOTE: `datasets` below is the filtered split, not the `datasets` package (only `load_dataset` is imported from it)\n",
    "raw_datasets = load_dataset('code_search_net', 'python')\n",
    "datasets = raw_datasets['train'].filter(lambda x: 'apache/spark' in x['repository_name'])\n",
    "# Indexing with an int returns a dict whose values are single strings\n",
    "print(datasets[8]['whole_func_string'])\n",
    "# Indexing with a slice still returns a dict, but each value is a list\n",
    "print(datasets[8: 10]['whole_func_string'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "Ky6gh_4TsiSQ"
   },
   "outputs": [],
   "source": [
    "class char_tokenizer:\n",
    "    '''\n",
    "    Character-level tokenizer.\n",
    "\n",
    "    Every distinct character in the corpus gets its own id, plus two special\n",
    "    tokens: '<|b|>' (begin) and '<|e|>' (end).\n",
    "\n",
    "    Parameters\n",
    "    ----\n",
    "    data : iterable of str, corpus used to build the vocabulary\n",
    "    begin_ind : int, id of the begin token '<|b|>'\n",
    "    end_ind : int, id of the end token '<|e|>' (must differ from begin_ind)\n",
    "    '''\n",
    "\n",
    "    def __init__(self, data, begin_ind=0, end_ind=1):\n",
    "        if begin_ind == end_ind:\n",
    "            raise ValueError('begin_ind and end_ind must be different')\n",
    "        # The vocabulary is every character that occurs in the data\n",
    "        chars = sorted(set(''.join(data)))\n",
    "        # Number the characters while skipping the two reserved ids, so a real\n",
    "        # character can never share an id with a special token. With the\n",
    "        # default ids 0 and 1 the characters are numbered 2, 3, 4, ...\n",
    "        reserved = {begin_ind, end_ind}\n",
    "        self.char2ind = {}\n",
    "        next_ind = 0\n",
    "        for s in chars:\n",
    "            while next_ind in reserved:\n",
    "                next_ind += 1\n",
    "            self.char2ind[s] = next_ind\n",
    "            next_ind += 1\n",
    "        self.char2ind['<|b|>'] = begin_ind\n",
    "        self.char2ind['<|e|>'] = end_ind\n",
    "        self.begin_ind = begin_ind\n",
    "        self.end_ind = end_ind\n",
    "        self.ind2char = {i: s for s, i in self.char2ind.items()}\n",
    "\n",
    "    def encode(self, text):\n",
    "        '''\n",
    "        Encode text into a list of ids.\n",
    "        Parameters\n",
    "        ----\n",
    "        text : str, text to encode\n",
    "        Raises\n",
    "        ----\n",
    "        KeyError if text contains a character that is not in the vocabulary\n",
    "        '''\n",
    "        return [self.char2ind[c] for c in text]\n",
    "\n",
    "    def decode(self, enc):\n",
    "        '''\n",
    "        Decode ids back into characters.\n",
    "        Parameters\n",
    "        ----\n",
    "        enc : int or iterable of int; objects with a `tolist` method\n",
    "              (e.g. numpy / torch arrays and scalars) are converted first\n",
    "        Returns\n",
    "        ----\n",
    "        str for a single id, list[str] for a sequence of ids\n",
    "        '''\n",
    "        if hasattr(enc, 'tolist'):\n",
    "            # torch tensor elements hash by object identity (they would not match\n",
    "            # the int dict keys) and numpy/torch scalars are not `int` instances\n",
    "            enc = enc.tolist()\n",
    "        if isinstance(enc, int):\n",
    "            return self.ind2char[enc]\n",
    "        return [self.ind2char[i] for i in enc]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "yDS5elR7siSR",
    "outputId": "5c2e18d7-3f81-4873-a14d-c8a7b900a0d1"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('def postappend(self):', 99)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity-check the tokenizer: encode + decode must give back the original text\n",
    "example_text = 'def postappend(self):'\n",
    "tok = char_tokenizer(datasets['whole_func_string'])\n",
    "(\n",
    "    ''.join(tok.decode(tok.encode(example_text))),  # round-tripped text\n",
    "    len(tok.char2ind),  # vocabulary size, including the 2 special tokens\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "egLVxc1VsiSR",
    "outputId": "5d0b0760-2267-47d8-93f1-9bb326f25719"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<|b|><|b|><|b|><|b|><|b|><|b|><|b|><|b|><|b|><|b|> ---> d\n",
      "<|b|><|b|><|b|><|b|><|b|><|b|><|b|><|b|><|b|>d ---> e\n",
      "<|b|><|b|><|b|><|b|><|b|><|b|><|b|><|b|>de ---> f\n",
      "<|b|><|b|><|b|><|b|><|b|><|b|><|b|>def --->  \n",
      "<|b|><|b|><|b|><|b|><|b|><|b|>def  ---> p\n",
      "<|b|><|b|><|b|><|b|><|b|>def p ---> o\n",
      "<|b|><|b|><|b|><|b|>def po ---> s\n",
      "<|b|><|b|><|b|>def pos ---> t\n",
      "<|b|><|b|>def post ---> a\n",
      "<|b|>def posta ---> p\n",
      "def postap ---> p\n",
      "ef postapp ---> e\n",
      "f postappe ---> n\n",
      " postappen ---> d\n",
      "postappend ---> (\n",
      "ostappend( ---> s\n",
      "stappend(s ---> e\n",
      "tappend(se ---> l\n",
      "append(sel ---> f\n",
      "ppend(self ---> )\n",
      "pend(self) ---> :\n",
      "end(self): ---> <|e|>\n"
     ]
    }
   ],
   "source": [
    "def autoregressive_trans(text, tokenizer, context_length=context_length):\n",
    "    '''\n",
    "    Turn a text into autoregressive training samples: each sample pairs a\n",
    "    window of `context_length` preceding token ids with the id that follows it.\n",
    "    Parameters\n",
    "    ----\n",
    "    text : str, the text\n",
    "    tokenizer : tokenizer providing `encode`, `begin_ind` and `end_ind`\n",
    "    context_length : int, length of the context window; the default is the\n",
    "                     notebook-level `context_length` at definition time\n",
    "    Returns\n",
    "    ----\n",
    "    inputs : list[list[int]], context windows (features), one per sample\n",
    "    labels : list[int], the token id to predict for each window\n",
    "    Raises\n",
    "    ----\n",
    "    ValueError if context_length is negative\n",
    "    '''\n",
    "    if context_length < 0:\n",
    "        raise ValueError('context_length must be non-negative')\n",
    "    # Left-pad with begin tokens so the first characters also get a full\n",
    "    # context, and append the end token so the end of text is a prediction target\n",
    "    x = [tokenizer.begin_ind] * context_length + tokenizer.encode(text) + [tokenizer.end_ind]\n",
    "    # Sample i uses x[i: i + context_length] as context and x[i + context_length]\n",
    "    # as label, so the labels are exactly everything after the padding\n",
    "    inputs = [x[i: i + context_length] for i in range(len(x) - context_length)]\n",
    "    labels = x[context_length:]\n",
    "    return inputs, labels\n",
    "\n",
    "# Show the autoregressive (context ---> next character) pairs built from the example text\n",
    "inputs, labels = autoregressive_trans(example_text, tok)\n",
    "for context, target in zip(inputs, labels):\n",
    "    print(''.join(tok.decode(context)), '--->', tok.decode(target))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "7kOWxKCisiSR",
    "outputId": "82ab4276-dc85-42e2-c6a9-e2e99629928f"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'inputs': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "  [0, 0, 0, 0, 0, 0, 0, 0, 0, 71],\n",
       "  [0, 0, 0, 0, 0, 0, 0, 0, 71, 72],\n",
       "  [0, 0, 0, 0, 0, 0, 0, 71, 72, 73],\n",
       "  [0, 0, 0, 0, 0, 0, 71, 72, 73, 3],\n",
       "  [0, 0, 0, 0, 0, 71, 72, 73, 3, 87],\n",
       "  [0, 0, 0, 0, 71, 72, 73, 3, 87, 82],\n",
       "  [0, 0, 0, 71, 72, 73, 3, 87, 82, 66],\n",
       "  [0, 0, 71, 72, 73, 3, 87, 82, 66, 68],\n",
       "  [0, 71, 72, 73, 3, 87, 82, 66, 68, 85],\n",
       "  [71, 72, 73, 3, 87, 82, 66, 68, 85, 85],\n",
       "  [72, 73, 3, 87, 82, 66, 68, 85, 85, 82],\n",
       "  [73, 3, 87, 82, 66, 68, 85, 85, 82, 90],\n",
       "  [3, 87, 82, 66, 68, 85, 85, 82, 90, 66],\n",
       "  [87, 82, 66, 68, 85, 85, 82, 90, 66, 86],\n",
       "  [82, 66, 68, 85, 85, 82, 90, 66, 86, 70],\n",
       "  [66, 68, 85, 85, 82, 90, 66, 86, 70, 75],\n",
       "  [68, 85, 85, 82, 90, 66, 86, 70, 75, 72],\n",
       "  [85, 85, 82, 90, 66, 86, 70, 75, 72, 80],\n",
       "  [85, 82, 90, 66, 86, 70, 75, 72, 80, 68],\n",
       "  [82, 90, 66, 86, 70, 75, 72, 80, 68, 11],\n",
       "  [90, 66, 86, 70, 75, 72, 80, 68, 11, 86],\n",
       "  [66, 86, 70, 75, 72, 80, 68, 11, 86, 70],\n",
       "  [86, 70, 75, 72, 80, 68, 11, 86, 70, 75],\n",
       "  [70, 75, 72, 80, 68, 11, 86, 70, 75, 72],\n",
       "  [75, 72, 80, 68, 11, 86, 70, 75, 72, 80],\n",
       "  [72, 80, 68, 11, 86, 70, 75, 72, 80, 68],\n",
       "  [80, 68, 11, 86, 70, 75, 72, 80, 68, 12],\n",
       "  [68, 11, 86, 70, 75, 72, 80, 68, 12, 29],\n",
       "  [11, 86, 70, 75, 72, 80, 68, 12, 29, 2],\n",
       "  [86, 70, 75, 72, 80, 68, 12, 29, 2, 3],\n",
       "  [70, 75, 72, 80, 68, 12, 29, 2, 3, 3],\n",
       "  [75, 72, 80, 68, 12, 29, 2, 3, 3, 3],\n",
       "  [72, 80, 68, 12, 29, 2, 3, 3, 3, 3],\n",
       "  [80, 68, 12, 29, 2, 3, 3, 3, 3, 5],\n",
       "  [68, 12, 29, 2, 3, 3, 3, 3, 5, 5],\n",
       "  [12, 29, 2, 3, 3, 3, 3, 5, 5, 5],\n",
       "  [29, 2, 3, 3, 3, 3, 5, 5, 5, 3],\n",
       "  [2, 3, 3, 3, 3, 5, 5, 5, 3, 38],\n",
       "  [3, 3, 3, 3, 5, 5, 5, 3, 38, 82],\n",
       "  [3, 3, 3, 5, 5, 5, 3, 38, 82, 81],\n",
       "  [3, 3, 5, 5, 5, 3, 38, 82, 81, 89],\n",
       "  [3, 5, 5, 5, 3, 38, 82, 81, 89, 72],\n",
       "  [5, 5, 5, 3, 38, 82, 81, 89, 72, 85],\n",
       "  [5, 5, 3, 38, 82, 81, 89, 72, 85, 87],\n",
       "  [5, 3, 38, 82, 81, 89, 72, 85, 87, 3],\n",
       "  [3, 38, 82, 81, 89, 72, 85, 87, 3, 68],\n",
       "  [38, 82, 81, 89, 72, 85, 87, 3, 68, 3],\n",
       "  [82, 81, 89, 72, 85, 87, 3, 68, 3, 86],\n",
       "  [81, 89, 72, 85, 87, 3, 68, 3, 86, 70],\n",
       "  [89, 72, 85, 87, 3, 68, 3, 86, 70, 75],\n",
       "  [72, 85, 87, 3, 68, 3, 86, 70, 75, 72],\n",
       "  [85, 87, 3, 68, 3, 86, 70, 75, 72, 80],\n",
       "  [87, 3, 68, 3, 86, 70, 75, 72, 80, 68],\n",
       "  [3, 68, 3, 86, 70, 75, 72, 80, 68, 3],\n",
       "  [68, 3, 86, 70, 75, 72, 80, 68, 3, 73],\n",
       "  [3, 86, 70, 75, 72, 80, 68, 3, 73, 85],\n",
       "  [86, 70, 75, 72, 80, 68, 3, 73, 85, 82],\n",
       "  [70, 75, 72, 80, 68, 3, 73, 85, 82, 80],\n",
       "  [75, 72, 80, 68, 3, 73, 85, 82, 80, 3],\n",
       "  [72, 80, 68, 3, 73, 85, 82, 80, 3, 54],\n",
       "  [80, 68, 3, 73, 85, 82, 80, 3, 54, 83],\n",
       "  [68, 3, 73, 85, 82, 80, 3, 54, 83, 68],\n",
       "  [3, 73, 85, 82, 80, 3, 54, 83, 68, 85],\n",
       "  [73, 85, 82, 80, 3, 54, 83, 68, 85, 78],\n",
       "  [85, 82, 80, 3, 54, 83, 68, 85, 78, 3],\n",
       "  [82, 80, 3, 54, 83, 68, 85, 78, 3, 87],\n",
       "  [80, 3, 54, 83, 68, 85, 78, 3, 87, 82],\n",
       "  [3, 54, 83, 68, 85, 78, 3, 87, 82, 3],\n",
       "  [54, 83, 68, 85, 78, 3, 87, 82, 3, 36],\n",
       "  [83, 68, 85, 78, 3, 87, 82, 3, 36, 85],\n",
       "  [68, 85, 78, 3, 87, 82, 3, 36, 85, 85],\n",
       "  [85, 78, 3, 87, 82, 3, 36, 85, 85, 82],\n",
       "  [78, 3, 87, 82, 3, 36, 85, 85, 82, 90],\n",
       "  [3, 87, 82, 3, 36, 85, 85, 82, 90, 2],\n",
       "  [87, 82, 3, 36, 85, 85, 82, 90, 2, 3],\n",
       "  [82, 3, 36, 85, 85, 82, 90, 2, 3, 3],\n",
       "  [3, 36, 85, 85, 82, 90, 2, 3, 3, 3],\n",
       "  [36, 85, 85, 82, 90, 2, 3, 3, 3, 3],\n",
       "  [85, 85, 82, 90, 2, 3, 3, 3, 3, 5],\n",
       "  [85, 82, 90, 2, 3, 3, 3, 3, 5, 5],\n",
       "  [82, 90, 2, 3, 3, 3, 3, 5, 5, 5],\n",
       "  [90, 2, 3, 3, 3, 3, 5, 5, 5, 2],\n",
       "  [2, 3, 3, 3, 3, 5, 5, 5, 2, 3],\n",
       "  [3, 3, 3, 3, 5, 5, 5, 2, 3, 3],\n",
       "  [3, 3, 3, 5, 5, 5, 2, 3, 3, 3],\n",
       "  [3, 3, 5, 5, 5, 2, 3, 3, 3, 3],\n",
       "  [3, 5, 5, 5, 2, 3, 3, 3, 3, 76],\n",
       "  [5, 5, 5, 2, 3, 3, 3, 3, 76, 80],\n",
       "  [5, 5, 2, 3, 3, 3, 3, 76, 80, 83],\n",
       "  [5, 2, 3, 3, 3, 3, 76, 80, 83, 82],\n",
       "  [2, 3, 3, 3, 3, 76, 80, 83, 82, 85],\n",
       "  [3, 3, 3, 3, 76, 80, 83, 82, 85, 87],\n",
       "  [3, 3, 3, 76, 80, 83, 82, 85, 87, 3],\n",
       "  [3, 3, 76, 80, 83, 82, 85, 87, 3, 83],\n",
       "  [3, 76, 80, 83, 82, 85, 87, 3, 83, 92],\n",
       "  [76, 80, 83, 82, 85, 87, 3, 83, 92, 68],\n",
       "  [80, 83, 82, 85, 87, 3, 83, 92, 68, 85],\n",
       "  [83, 82, 85, 87, 3, 83, 92, 68, 85, 85],\n",
       "  [82, 85, 87, 3, 83, 92, 68, 85, 85, 82],\n",
       "  [85, 87, 3, 83, 92, 68, 85, 85, 82, 90],\n",
       "  [87, 3, 83, 92, 68, 85, 85, 82, 90, 3],\n",
       "  [3, 83, 92, 68, 85, 85, 82, 90, 3, 68],\n",
       "  [83, 92, 68, 85, 85, 82, 90, 3, 68, 86],\n",
       "  [92, 68, 85, 85, 82, 90, 3, 68, 86, 3],\n",
       "  [68, 85, 85, 82, 90, 3, 68, 86, 3, 83],\n",
       "  [85, 85, 82, 90, 3, 68, 86, 3, 83, 68],\n",
       "  [85, 82, 90, 3, 68, 86, 3, 83, 68, 2],\n",
       "  [82, 90, 3, 68, 86, 3, 83, 68, 2, 3],\n",
       "  [90, 3, 68, 86, 3, 83, 68, 2, 3, 3],\n",
       "  [3, 68, 86, 3, 83, 68, 2, 3, 3, 3],\n",
       "  [68, 86, 3, 83, 68, 2, 3, 3, 3, 3],\n",
       "  [86, 3, 83, 68, 2, 3, 3, 3, 3, 73],\n",
       "  [3, 83, 68, 2, 3, 3, 3, 3, 73, 76],\n",
       "  [83, 68, 2, 3, 3, 3, 3, 73, 76, 72],\n",
       "  [68, 2, 3, 3, 3, 3, 73, 76, 72, 79],\n",
       "  [2, 3, 3, 3, 3, 73, 76, 72, 79, 71],\n",
       "  [3, 3, 3, 3, 73, 76, 72, 79, 71, 86],\n",
       "  [3, 3, 3, 73, 76, 72, 79, 71, 86, 3],\n",
       "  [3, 3, 73, 76, 72, 79, 71, 86, 3, 32],\n",
       "  [3, 73, 76, 72, 79, 71, 86, 3, 32, 3],\n",
       "  [73, 76, 72, 79, 71, 86, 3, 32, 3, 62],\n",
       "  [76, 72, 79, 71, 86, 3, 32, 3, 62, 83],\n",
       "  [72, 79, 71, 86, 3, 32, 3, 62, 83, 68],\n",
       "  [79, 71, 86, 3, 32, 3, 62, 83, 68, 17],\n",
       "  [71, 86, 3, 32, 3, 62, 83, 68, 17, 73],\n",
       "  [86, 3, 32, 3, 62, 83, 68, 17, 73, 76],\n",
       "  [3, 32, 3, 62, 83, 68, 17, 73, 76, 72],\n",
       "  [32, 3, 62, 83, 68, 17, 73, 76, 72, 79],\n",
       "  [3, 62, 83, 68, 17, 73, 76, 72, 79, 71],\n",
       "  [62, 83, 68, 17, 73, 76, 72, 79, 71, 11],\n",
       "  [83, 68, 17, 73, 76, 72, 79, 71, 11, 73],\n",
       "  [68, 17, 73, 76, 72, 79, 71, 11, 73, 76],\n",
       "  [17, 73, 76, 72, 79, 71, 11, 73, 76, 72],\n",
       "  [73, 76, 72, 79, 71, 11, 73, 76, 72, 79],\n",
       "  [76, 72, 79, 71, 11, 73, 76, 72, 79, 71],\n",
       "  [72, 79, 71, 11, 73, 76, 72, 79, 71, 17],\n",
       "  [79, 71, 11, 73, 76, 72, 79, 71, 17, 81],\n",
       "  [71, 11, 73, 76, 72, 79, 71, 17, 81, 68],\n",
       "  [11, 73, 76, 72, 79, 71, 17, 81, 68, 80],\n",
       "  [73, 76, 72, 79, 71, 17, 81, 68, 80, 72],\n",
       "  [76, 72, 79, 71, 17, 81, 68, 80, 72, 15],\n",
       "  [72, 79, 71, 17, 81, 68, 80, 72, 15, 3],\n",
       "  [79, 71, 17, 81, 68, 80, 72, 15, 3, 87],\n",
       "  [71, 17, 81, 68, 80, 72, 15, 3, 87, 82],\n",
       "  [17, 81, 68, 80, 72, 15, 3, 87, 82, 66],\n",
       "  [81, 68, 80, 72, 15, 3, 87, 82, 66, 68],\n",
       "  [68, 80, 72, 15, 3, 87, 82, 66, 68, 85],\n",
       "  [80, 72, 15, 3, 87, 82, 66, 68, 85, 85],\n",
       "  [72, 15, 3, 87, 82, 66, 68, 85, 85, 82],\n",
       "  [15, 3, 87, 82, 66, 68, 85, 85, 82, 90],\n",
       "  [3, 87, 82, 66, 68, 85, 85, 82, 90, 66],\n",
       "  [87, 82, 66, 68, 85, 85, 82, 90, 66, 87],\n",
       "  [82, 66, 68, 85, 85, 82, 90, 66, 87, 92],\n",
       "  [66, 68, 85, 85, 82, 90, 66, 87, 92, 83],\n",
       "  [68, 85, 85, 82, 90, 66, 87, 92, 83, 72],\n",
       "  [85, 85, 82, 90, 66, 87, 92, 83, 72, 11],\n",
       "  [85, 82, 90, 66, 87, 92, 83, 72, 11, 73],\n",
       "  [82, 90, 66, 87, 92, 83, 72, 11, 73, 76],\n",
       "  [90, 66, 87, 92, 83, 72, 11, 73, 76, 72],\n",
       "  [66, 87, 92, 83, 72, 11, 73, 76, 72, 79],\n",
       "  [87, 92, 83, 72, 11, 73, 76, 72, 79, 71],\n",
       "  [92, 83, 72, 11, 73, 76, 72, 79, 71, 17],\n",
       "  [83, 72, 11, 73, 76, 72, 79, 71, 17, 71],\n",
       "  [72, 11, 73, 76, 72, 79, 71, 17, 71, 68],\n",
       "  [11, 73, 76, 72, 79, 71, 17, 71, 68, 87],\n",
       "  [73, 76, 72, 79, 71, 17, 71, 68, 87, 68],\n",
       "  [76, 72, 79, 71, 17, 71, 68, 87, 68, 55],\n",
       "  [72, 79, 71, 17, 71, 68, 87, 68, 55, 92],\n",
       "  [79, 71, 17, 71, 68, 87, 68, 55, 92, 83],\n",
       "  [71, 17, 71, 68, 87, 68, 55, 92, 83, 72],\n",
       "  [17, 71, 68, 87, 68, 55, 92, 83, 72, 12],\n",
       "  [71, 68, 87, 68, 55, 92, 83, 72, 12, 15],\n",
       "  [68, 87, 68, 55, 92, 83, 72, 12, 15, 3],\n",
       "  [87, 68, 55, 92, 83, 72, 12, 15, 3, 81],\n",
       "  [68, 55, 92, 83, 72, 12, 15, 3, 81, 88],\n",
       "  [55, 92, 83, 72, 12, 15, 3, 81, 88, 79],\n",
       "  [92, 83, 72, 12, 15, 3, 81, 88, 79, 79],\n",
       "  [83, 72, 12, 15, 3, 81, 88, 79, 79, 68],\n",
       "  [72, 12, 15, 3, 81, 88, 79, 79, 68, 69],\n",
       "  [12, 15, 3, 81, 88, 79, 79, 68, 69, 79],\n",
       "  [15, 3, 81, 88, 79, 79, 68, 69, 79, 72],\n",
       "  [3, 81, 88, 79, 79, 68, 69, 79, 72, 32],\n",
       "  [81, 88, 79, 79, 68, 69, 79, 72, 32, 73],\n",
       "  [88, 79, 79, 68, 69, 79, 72, 32, 73, 76],\n",
       "  [79, 79, 68, 69, 79, 72, 32, 73, 76, 72],\n",
       "  [79, 68, 69, 79, 72, 32, 73, 76, 72, 79],\n",
       "  [68, 69, 79, 72, 32, 73, 76, 72, 79, 71],\n",
       "  [69, 79, 72, 32, 73, 76, 72, 79, 71, 17],\n",
       "  [79, 72, 32, 73, 76, 72, 79, 71, 17, 81],\n",
       "  [72, 32, 73, 76, 72, 79, 71, 17, 81, 88],\n",
       "  [32, 73, 76, 72, 79, 71, 17, 81, 88, 79],\n",
       "  [73, 76, 72, 79, 71, 17, 81, 88, 79, 79],\n",
       "  [76, 72, 79, 71, 17, 81, 88, 79, 79, 68],\n",
       "  [72, 79, 71, 17, 81, 88, 79, 79, 68, 69],\n",
       "  [79, 71, 17, 81, 88, 79, 79, 68, 69, 79],\n",
       "  [71, 17, 81, 88, 79, 79, 68, 69, 79, 72],\n",
       "  [17, 81, 88, 79, 79, 68, 69, 79, 72, 12],\n",
       "  [81, 88, 79, 79, 68, 69, 79, 72, 12, 2],\n",
       "  [88, 79, 79, 68, 69, 79, 72, 12, 2, 3],\n",
       "  [79, 79, 68, 69, 79, 72, 12, 2, 3, 3],\n",
       "  [79, 68, 69, 79, 72, 12, 2, 3, 3, 3],\n",
       "  [68, 69, 79, 72, 12, 2, 3, 3, 3, 3],\n",
       "  [69, 79, 72, 12, 2, 3, 3, 3, 3, 3],\n",
       "  [79, 72, 12, 2, 3, 3, 3, 3, 3, 3],\n",
       "  [72, 12, 2, 3, 3, 3, 3, 3, 3, 3],\n",
       "  [12, 2, 3, 3, 3, 3, 3, 3, 3, 3],\n",
       "  [2, 3, 3, 3, 3, 3, 3, 3, 3, 3],\n",
       "  [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],\n",
       "  [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],\n",
       "  [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],\n",
       "  [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],\n",
       "  [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],\n",
       "  [3, 3, 3, 3, 3, 3, 3, 3, 3, 73],\n",
       "  [3, 3, 3, 3, 3, 3, 3, 3, 73, 82],\n",
       "  [3, 3, 3, 3, 3, 3, 3, 73, 82, 85],\n",
       "  [3, 3, 3, 3, 3, 3, 73, 82, 85, 3],\n",
       "  [3, 3, 3, 3, 3, 73, 82, 85, 3, 73],\n",
       "  [3, 3, 3, 3, 73, 82, 85, 3, 73, 76],\n",
       "  [3, 3, 3, 73, 82, 85, 3, 73, 76, 72],\n",
       "  [3, 3, 73, 82, 85, 3, 73, 76, 72, 79],\n",
       "  [3, 73, 82, 85, 3, 73, 76, 72, 79, 71],\n",
       "  [73, 82, 85, 3, 73, 76, 72, 79, 71, 3],\n",
       "  [82, 85, 3, 73, 76, 72, 79, 71, 3, 76],\n",
       "  [85, 3, 73, 76, 72, 79, 71, 3, 76, 81],\n",
       "  [3, 73, 76, 72, 79, 71, 3, 76, 81, 3],\n",
       "  [73, 76, 72, 79, 71, 3, 76, 81, 3, 86],\n",
       "  [76, 72, 79, 71, 3, 76, 81, 3, 86, 70],\n",
       "  [72, 79, 71, 3, 76, 81, 3, 86, 70, 75],\n",
       "  [79, 71, 3, 76, 81, 3, 86, 70, 75, 72],\n",
       "  [71, 3, 76, 81, 3, 86, 70, 75, 72, 80],\n",
       "  [3, 76, 81, 3, 86, 70, 75, 72, 80, 68],\n",
       "  [76, 81, 3, 86, 70, 75, 72, 80, 68, 64],\n",
       "  [81, 3, 86, 70, 75, 72, 80, 68, 64, 2],\n",
       "  [3, 86, 70, 75, 72, 80, 68, 64, 2, 3],\n",
       "  [86, 70, 75, 72, 80, 68, 64, 2, 3, 3],\n",
       "  [70, 75, 72, 80, 68, 64, 2, 3, 3, 3],\n",
       "  [75, 72, 80, 68, 64, 2, 3, 3, 3, 3],\n",
       "  [72, 80, 68, 64, 2, 3, 3, 3, 3, 85],\n",
       "  [80, 68, 64, 2, 3, 3, 3, 3, 85, 72],\n",
       "  [68, 64, 2, 3, 3, 3, 3, 85, 72, 87],\n",
       "  [64, 2, 3, 3, 3, 3, 85, 72, 87, 88],\n",
       "  [2, 3, 3, 3, 3, 85, 72, 87, 88, 85],\n",
       "  [3, 3, 3, 3, 85, 72, 87, 88, 85, 81],\n",
       "  [3, 3, 3, 85, 72, 87, 88, 85, 81, 3],\n",
       "  [3, 3, 85, 72, 87, 88, 85, 81, 3, 83],\n",
       "  [3, 85, 72, 87, 88, 85, 81, 3, 83, 68],\n",
       "  [85, 72, 87, 88, 85, 81, 3, 83, 68, 17],\n",
       "  [72, 87, 88, 85, 81, 3, 83, 68, 17, 86],\n",
       "  [87, 88, 85, 81, 3, 83, 68, 17, 86, 70],\n",
       "  [88, 85, 81, 3, 83, 68, 17, 86, 70, 75],\n",
       "  [85, 81, 3, 83, 68, 17, 86, 70, 75, 72],\n",
       "  [81, 3, 83, 68, 17, 86, 70, 75, 72, 80],\n",
       "  [3, 83, 68, 17, 86, 70, 75, 72, 80, 68],\n",
       "  [83, 68, 17, 86, 70, 75, 72, 80, 68, 11],\n",
       "  [68, 17, 86, 70, 75, 72, 80, 68, 11, 73],\n",
       "  [17, 86, 70, 75, 72, 80, 68, 11, 73, 76],\n",
       "  [86, 70, 75, 72, 80, 68, 11, 73, 76, 72],\n",
       "  [70, 75, 72, 80, 68, 11, 73, 76, 72, 79],\n",
       "  [75, 72, 80, 68, 11, 73, 76, 72, 79, 71],\n",
       "  [72, 80, 68, 11, 73, 76, 72, 79, 71, 86],\n",
       "  [80, 68, 11, 73, 76, 72, 79, 71, 86, 12]],\n",
       " 'labels': [71,\n",
       "  72,\n",
       "  73,\n",
       "  3,\n",
       "  87,\n",
       "  82,\n",
       "  66,\n",
       "  68,\n",
       "  85,\n",
       "  85,\n",
       "  82,\n",
       "  90,\n",
       "  66,\n",
       "  86,\n",
       "  70,\n",
       "  75,\n",
       "  72,\n",
       "  80,\n",
       "  68,\n",
       "  11,\n",
       "  86,\n",
       "  70,\n",
       "  75,\n",
       "  72,\n",
       "  80,\n",
       "  68,\n",
       "  12,\n",
       "  29,\n",
       "  2,\n",
       "  3,\n",
       "  3,\n",
       "  3,\n",
       "  3,\n",
       "  5,\n",
       "  5,\n",
       "  5,\n",
       "  3,\n",
       "  38,\n",
       "  82,\n",
       "  81,\n",
       "  89,\n",
       "  72,\n",
       "  85,\n",
       "  87,\n",
       "  3,\n",
       "  68,\n",
       "  3,\n",
       "  86,\n",
       "  70,\n",
       "  75,\n",
       "  72,\n",
       "  80,\n",
       "  68,\n",
       "  3,\n",
       "  73,\n",
       "  85,\n",
       "  82,\n",
       "  80,\n",
       "  3,\n",
       "  54,\n",
       "  83,\n",
       "  68,\n",
       "  85,\n",
       "  78,\n",
       "  3,\n",
       "  87,\n",
       "  82,\n",
       "  3,\n",
       "  36,\n",
       "  85,\n",
       "  85,\n",
       "  82,\n",
       "  90,\n",
       "  2,\n",
       "  3,\n",
       "  3,\n",
       "  3,\n",
       "  3,\n",
       "  5,\n",
       "  5,\n",
       "  5,\n",
       "  2,\n",
       "  3,\n",
       "  3,\n",
       "  3,\n",
       "  3,\n",
       "  76,\n",
       "  80,\n",
       "  83,\n",
       "  82,\n",
       "  85,\n",
       "  87,\n",
       "  3,\n",
       "  83,\n",
       "  92,\n",
       "  68,\n",
       "  85,\n",
       "  85,\n",
       "  82,\n",
       "  90,\n",
       "  3,\n",
       "  68,\n",
       "  86,\n",
       "  3,\n",
       "  83,\n",
       "  68,\n",
       "  2,\n",
       "  3,\n",
       "  3,\n",
       "  3,\n",
       "  3,\n",
       "  73,\n",
       "  76,\n",
       "  72,\n",
       "  79,\n",
       "  71,\n",
       "  86,\n",
       "  3,\n",
       "  32,\n",
       "  3,\n",
       "  62,\n",
       "  83,\n",
       "  68,\n",
       "  17,\n",
       "  73,\n",
       "  76,\n",
       "  72,\n",
       "  79,\n",
       "  71,\n",
       "  11,\n",
       "  73,\n",
       "  76,\n",
       "  72,\n",
       "  79,\n",
       "  71,\n",
       "  17,\n",
       "  81,\n",
       "  68,\n",
       "  80,\n",
       "  72,\n",
       "  15,\n",
       "  3,\n",
       "  87,\n",
       "  82,\n",
       "  66,\n",
       "  68,\n",
       "  85,\n",
       "  85,\n",
       "  82,\n",
       "  90,\n",
       "  66,\n",
       "  87,\n",
       "  92,\n",
       "  83,\n",
       "  72,\n",
       "  11,\n",
       "  73,\n",
       "  76,\n",
       "  72,\n",
       "  79,\n",
       "  71,\n",
       "  17,\n",
       "  71,\n",
       "  68,\n",
       "  87,\n",
       "  68,\n",
       "  55,\n",
       "  92,\n",
       "  83,\n",
       "  72,\n",
       "  12,\n",
       "  15,\n",
       "  3,\n",
       "  81,\n",
       "  88,\n",
       "  79,\n",
       "  79,\n",
       "  68,\n",
       "  69,\n",
       "  79,\n",
       "  72,\n",
       "  32,\n",
       "  73,\n",
       "  76,\n",
       "  72,\n",
       "  79,\n",
       "  71,\n",
       "  17,\n",
       "  81,\n",
       "  88,\n",
       "  79,\n",
       "  79,\n",
       "  68,\n",
       "  69,\n",
       "  79,\n",
       "  72,\n",
       "  12,\n",
       "  2,\n",
       "  3,\n",
       "  3,\n",
       "  3,\n",
       "  3,\n",
       "  3,\n",
       "  3,\n",
       "  3,\n",
       "  3,\n",
       "  3,\n",
       "  3,\n",
       "  3,\n",
       "  3,\n",
       "  3,\n",
       "  3,\n",
       "  73,\n",
       "  82,\n",
       "  85,\n",
       "  3,\n",
       "  73,\n",
       "  76,\n",
       "  72,\n",
       "  79,\n",
       "  71,\n",
       "  3,\n",
       "  76,\n",
       "  81,\n",
       "  3,\n",
       "  86,\n",
       "  70,\n",
       "  75,\n",
       "  72,\n",
       "  80,\n",
       "  68,\n",
       "  64,\n",
       "  2,\n",
       "  3,\n",
       "  3,\n",
       "  3,\n",
       "  3,\n",
       "  85,\n",
       "  72,\n",
       "  87,\n",
       "  88,\n",
       "  85,\n",
       "  81,\n",
       "  3,\n",
       "  83,\n",
       "  68,\n",
       "  17,\n",
       "  86,\n",
       "  70,\n",
       "  75,\n",
       "  72,\n",
       "  80,\n",
       "  68,\n",
       "  11,\n",
       "  73,\n",
       "  76,\n",
       "  72,\n",
       "  79,\n",
       "  71,\n",
       "  86,\n",
       "  12,\n",
       "  1]}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def process(data, tokenizer=None, context_length=None):\n",
    "    '''\n",
    "    Turn source code into autoregressive training data; meant for `Dataset.map`.\n",
    "    Parameters\n",
    "    ----\n",
    "    data : dict whose 'whole_func_string' entry is a str (plain map) or a\n",
    "           list[str] (map with batched=True)\n",
    "    tokenizer : tokenizer to use; None means the notebook-level `tok`,\n",
    "                looked up at call time\n",
    "    context_length : int or None; None keeps the default of `autoregressive_trans`\n",
    "    Returns\n",
    "    ----\n",
    "    dict with 'inputs' and 'labels'. For a str it holds the samples of that one\n",
    "    text; for a list, the samples of all texts concatenated. The batched output\n",
    "    therefore has a different number of rows than the input batch, so the\n",
    "    original columns must be dropped (`remove_columns`).\n",
    "    '''\n",
    "    if tokenizer is None:\n",
    "        tokenizer = tok\n",
    "    extra = {} if context_length is None else {'context_length': context_length}\n",
    "    text = data['whole_func_string']\n",
    "    # Plain map: a single string\n",
    "    if isinstance(text, str):\n",
    "        inputs, labels = autoregressive_trans(text, tokenizer, **extra)\n",
    "        return {'inputs': inputs, 'labels': labels}\n",
    "    # Batched map: a list of strings, whose samples are concatenated\n",
    "    inputs, labels = [], []\n",
    "    for func_text in text:\n",
    "        func_inputs, func_labels = autoregressive_trans(func_text, tokenizer, **extra)\n",
    "        inputs.extend(func_inputs)\n",
    "        labels.extend(func_labels)\n",
    "    return {'inputs': inputs, 'labels': labels}\n",
    "\n",
    "# Example: a slice yields a dict of lists, which goes through the batched code path\n",
    "process(data=datasets[8:9])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "U0Ojs8kbsiSS",
    "outputId": "0e571fc8-2432-4655-d276-538a79031000"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([645401, 10]), torch.Size([645401]))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Split into train/test sets (10% test; fixed seed so the split is reproducible)\n",
    "split_datasets = datasets.train_test_split(test_size=0.1, seed=1024, shuffle=True)\n",
    "# Turn every function into (inputs, labels) samples. With batched=True the number\n",
    "# of rows changes, so the original columns have to be removed\n",
    "tokenized = split_datasets.map(process, batched=True, remove_columns=datasets.column_names)\n",
    "tokenized.set_format(type='torch', device=device)\n",
    "\n",
    "tokenized['train']['inputs'].shape, tokenized['train']['labels'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "SdfNybxcsiSS",
    "outputId": "d8276b00-2799-4a1c-d0b5-41c9e515803f"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'inputs': tensor([[38, 85, 72,  ..., 12,  2,  3],\n",
       "         [ 3, 76, 73,  ..., 86, 87, 66],\n",
       "         [80, 72, 15,  ..., 75, 15,  3],\n",
       "         ...,\n",
       "         [75, 76, 86,  ..., 79,  3, 68],\n",
       "         [87, 68, 81,  ..., 82, 90, 15],\n",
       "         [ 3,  3,  3,  ..., 68, 85, 68]], device='cuda:0'),\n",
       " 'labels': tensor([ 3, 86, 86, 90, 72, 87,  3, 81, 76, 29, 74, 87, 87, 87, 71,  3,  3,  3,\n",
       "         15,  3, 68,  3, 87, 75, 86, 17, 75, 13,  3, 80,  3, 83, 86, 68,  3, 83,\n",
       "          3, 81, 53, 83, 11, 66,  3, 85, 85, 82, 76, 85, 73, 14,  3, 75,  3,  3,\n",
       "         88,  3,  3,  3, 79, 82,  3, 50, 58, 11, 17,  3, 86, 93,  3, 87,  3, 74,\n",
       "         78, 78, 80, 54,  3,  2, 81, 80,  3, 81,  3,  3, 74,  3, 76, 82, 82,  3,\n",
       "         60, 12,  5, 73, 75, 72, 87,  3, 66, 73, 69, 85,  3, 76,  3,  3,  3, 86,\n",
       "         87, 72,  2,  3, 70, 85,  3,  3, 72, 32,  3,  3, 68,  3, 79, 72, 88,  3,\n",
       "         70, 76, 11,  3, 80, 72,  3, 87, 71,  3, 76, 73, 11, 82,  8, 84,  3, 72,\n",
       "          3, 76,  3, 49, 79, 87,  3,  3, 73, 12, 79, 90, 82, 83,  3, 71, 86, 91,\n",
       "          3, 75, 87, 83, 75, 72, 87, 76,  3,  3, 71, 92, 17, 87, 10, 87,  6, 70,\n",
       "         10,  3, 88,  3, 81, 86, 12, 72,  3,  3, 81, 72, 19, 68, 76, 44, 73, 82,\n",
       "         70, 82, 68,  3, 88, 86,  3,  5, 67, 72, 79, 70, 81,  3,  3, 75, 87,  3,\n",
       "          3, 76,  6,  3, 72, 70, 87,  3,  3, 87, 71,  3, 68,  2, 83, 85, 71, 87,\n",
       "         72, 33, 15, 87, 74,  3,  3, 29, 68,  3, 75, 85, 80,  3, 70, 72, 70, 82,\n",
       "         73, 82, 70, 86, 15,  3, 68,  5, 86, 82,  3,  3,  3,  3, 87,  3, 17, 76,\n",
       "         75, 87, 93,  3,  2, 81, 90,  3, 91, 72, 19,  3,  3, 69, 49,  3, 71, 81,\n",
       "         70,  5, 33,  3, 87,  3, 76, 20, 29,  3, 71, 68, 88, 86, 86, 17, 77, 82,\n",
       "         88,  3, 73,  3, 76,  3,  3, 81, 79, 72, 73, 82,  3, 82, 83, 72, 54, 20,\n",
       "         72, 72, 81,  2,  3, 80, 87, 85, 83, 80, 87, 72, 72, 76, 82,  3, 27, 81,\n",
       "          3,  3,  3, 86, 80, 66, 17, 76, 71, 80, 87, 17, 79, 83, 76, 83,  3,  3,\n",
       "         72, 76,  3, 87, 87, 46, 11, 68, 72,  3,  3, 70,  3, 81,  3,  2,  3, 68,\n",
       "          3, 72,  3, 68,  3, 72, 73,  3, 82, 83,  3, 32, 68,  2, 68, 73, 76, 11,\n",
       "         68, 68,  3,  3,  2,  3, 72, 17,  3, 55, 51,  3,  3, 87, 87,  3, 72, 87,\n",
       "          3,  3, 15, 89, 82,  3, 73, 81,  3, 85,  3, 74, 17, 75, 85,  3, 80, 79,\n",
       "          3, 85,  3,  3, 15, 88,  3, 12, 77, 83, 87, 86, 74, 72, 91,  3,  3, 79,\n",
       "         71, 29, 72,  3, 72,  3, 68, 81, 73,  3, 71, 74, 86,  3, 87,  3, 79, 17,\n",
       "         85, 76,  3, 12, 85, 72, 72,  5,  3,  3, 90,  3, 17,  3,  3, 72, 86, 80,\n",
       "          3, 82, 86, 68, 72,  3, 69, 71, 86, 85, 85,  3, 77, 10, 71, 76,  3,  3,\n",
       "          3, 85, 87,  5, 76, 81,  3,  3, 90, 68, 71, 81, 85, 80,  3, 79, 68, 17,\n",
       "         79,  3, 86,  3, 70, 73,  1,  3,  3,  3,  3, 87, 86,  3, 86, 68,  3, 87,\n",
       "         72,  3,  2, 86, 72, 87, 79, 71, 81, 15, 86, 85, 87, 78, 86, 74,  3, 66,\n",
       "         79, 82, 76, 80, 82,  2, 79, 70, 74, 51, 68, 66,  3,  3, 79,  3, 71,  2,\n",
       "         86, 85,  3, 68,  3,  5, 39, 71, 72, 81, 87,  3,  3,  3,  3,  3,  3, 72,\n",
       "          3, 72, 88, 81,  3, 85, 82, 68, 39,  3, 72, 68,  2,  3, 68, 70, 80, 69,\n",
       "         72,  2, 49, 82, 82, 81, 90,  3, 76,  2, 66, 72,  3,  3,  3,  3, 88, 72,\n",
       "          3, 19,  3, 72, 81,  3, 76, 78,  3,  3, 76, 76, 89, 81, 80, 87,  3, 81,\n",
       "         68, 81, 79,  2, 11, 71,  3, 81, 72, 49, 87, 72, 76,  5, 73, 68,  3,  3,\n",
       "         79, 72, 71, 21,  3, 85, 17,  3,  3,  3, 79, 76,  3,  3, 68, 82,  3, 87,\n",
       "         81,  3, 87, 72,  3, 87, 21, 44,  3, 87, 68, 81,  3, 68, 82, 87,  3, 70,\n",
       "          3, 12, 72, 87, 85,  2,  3,  3, 75, 82, 50,  3, 79, 85, 68, 64,  3, 62,\n",
       "          8, 11, 79, 72,  3, 76, 85, 15, 33, 32, 68,  3,  5,  3, 76, 21, 88, 79,\n",
       "         49, 68,  3, 88, 51,  3,  3, 80, 68, 79,  3,  3,  2, 32,  3, 12,  2, 68,\n",
       "          3, 81, 17, 87, 71,  3,  3, 76, 76, 92,  3, 88,  3,  3, 49, 80,  3, 85,\n",
       "          3,  3, 19, 79, 49, 87,  3,  3, 76, 80,  3, 68,  2, 71, 11,  3, 13,  3,\n",
       "         72, 29, 76, 86, 82,  3, 29, 32,  3, 10, 81,  3,  3,  3, 87, 85, 17, 85,\n",
       "         10, 11, 85, 78,  3, 17, 78, 70, 21,  3, 83,  3,  3, 90,  3, 85,  3, 82,\n",
       "         71, 88,  3, 63, 80, 89, 54, 69, 11, 15,  3, 81,  3, 69, 72, 68,  3, 88,\n",
       "         42, 68,  3, 38, 19, 12, 41, 76, 87, 11,  3, 87,  3, 76, 79, 72, 82,  3,\n",
       "         70,  2,  2, 67,  3, 15, 80, 87,  2, 54,  2, 11, 11,  3, 70, 87, 87, 70,\n",
       "         75, 12, 11, 82, 87,  2, 81, 70, 70, 21, 82, 81,  3, 72, 86, 72,  2, 47,\n",
       "         87,  3, 89, 78, 29, 85, 68, 12, 86, 82, 49,  3,  3, 81,  3, 73, 76, 83,\n",
       "          3, 68, 79,  3, 72,  3,  3,  3, 72,  3, 83, 63,  3, 86, 82, 73, 71, 72,\n",
       "         54, 72,  3, 17, 92, 73,  3, 86,  3, 82, 15,  3, 72, 74,  3, 72,  3, 17,\n",
       "          3, 72, 81, 75, 72,  3,  3,  3, 87, 37, 75, 73,  3,  5, 86,  2,  5, 81,\n",
       "         68,  3, 15,  2, 85, 85,  3,  3,  3, 72,  3, 80, 72, 68,  3, 85, 11, 83,\n",
       "         62, 72,  3, 81, 86,  2, 75, 79,  3, 80], device='cuda:0')}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 构建数据读取器\n",
    "train_loader = DataLoader(tokenized['train'], batch_size=batch_size, shuffle=True)\n",
    "test_loader = DataLoader(tokenized['test'], batch_size=batch_size, shuffle=True)\n",
    "# 获取一个批量的数据\n",
    "next(iter(test_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "id": "WVXSXA6vsiSS"
   },
   "outputs": [],
   "source": [
    "class CharMLP(nn.Module):\n",
    "\n",
    "    def __init__(self, vs):\n",
    "        '''\n",
    "        根据文本背景预测下一个字母是什么\n",
    "        参数\n",
    "        ----\n",
    "        vs ：int，字典大小\n",
    "        '''\n",
    "        super().__init__()\n",
    "        # 文字嵌入层\n",
    "        self.embedding = nn.Embedding(vs, 30)\n",
    "        self.hidden1 = nn.Linear(300, 200)\n",
    "        self.hidden2 = nn.Linear(200, 100)\n",
    "        self.out = nn.Linear(100, vs)\n",
    "\n",
    "    def forward(self, x):\n",
    "        '''\n",
    "        向前传播\n",
    "        参数\n",
    "        ----\n",
    "        x ：torch.LongTensor，背景文本，其中的元素表示相应位置的字母在字典中的位置\n",
    "        返回\n",
    "        ----\n",
    "        h ：torch.FloatTensor，预测结果的logits\n",
    "        '''\n",
    "        # 因为背景文本的长度（context_length）等于10，\n",
    "        # 所以x的形状是(B, 10)，B表示批量数据的大小\n",
    "        B = x.shape[0]               # (B,  10)\n",
    "        emb = self.embedding(x)      # (B,  10, 30)\n",
    "        h = emb.view(B, -1)          # (B, 300)\n",
    "        h = F.relu(self.hidden1(h))  # (B, 200)\n",
    "        h = F.relu(self.hidden2(h))  # (B, 100)\n",
    "        h = self.out(h)              # (B,  vs)\n",
    "        return h\n",
    "\n",
    "model = CharMLP(len(tok.char2ind)).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "id": "82S3F9IXsiSS"
   },
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def generate(model, context, max_new_tokens=300):\n",
    "    '''\n",
    "    利用模型生成文本（反复使用模型进行预测）\n",
    "    参数\n",
    "    ----\n",
    "    model ：CharMLP，生成文本的模型\n",
    "    context ：torch.LongTensor，背景文本，形状为(1, 10)\n",
    "    max_new_tokens ：int，生成文本的最大长度\n",
    "    返回\n",
    "    ----\n",
    "    out ：list[int]，生成的文本\n",
    "    '''\n",
    "    out = []\n",
    "    # 将模型切换至评估模式\n",
    "    model.eval()\n",
    "    for _ in range(max_new_tokens):\n",
    "        logits = model(context)\n",
    "        probs = F.softmax(logits, dim=-1)\n",
    "        # 根据模型预测的概率，得到最终的预测结果（下一个字母）\n",
    "        # 这一步运算有一定随机性\n",
    "        ix = torch.multinomial(probs, num_samples=1)\n",
    "        # 利用模型的预测结果更新文本背景\n",
    "        context = torch.cat((context[:, 1:], ix), dim=1)\n",
    "        out.append(ix.item())\n",
    "        if ix.item() == tok.end_ind:\n",
    "            break\n",
    "    # 将模型切换至训练模式\n",
    "    model.train()\n",
    "    return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "4zSHFccTsiST",
    "outputId": "e0a6ca3f-fe74-49c6-caf1-0080eef4b158"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ")YN'E.ne'!XOzAYD\n",
      "F{tvö290&^#>P8(MZzJP<BJe@L9hJb`Q:*P2i;r@dfVR#L/sw pxS2!xLl4Y`(&V}\\[Vl%A!'zq<|e|>\n"
     ]
    }
   ],
   "source": [
    "# 使用模型来生成文本\n",
    "context = torch.zeros((1, 10), dtype=torch.long, device=device)\n",
    "print(''.join(tok.decode(generate(model, context))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "bo3UzjmXsiST",
    "outputId": "37337a6e-a3c0-40aa-84cb-6b88ec257da7"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'train': 4.5956830978393555, 'test': 4.594418525695801}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def estimate_loss(model):\n",
    "    re = {}\n",
    "    # 将模型切换至评估模式\n",
    "    model.eval()\n",
    "    re['train'] = _loss(model, train_loader)\n",
    "    re['test'] = _loss(model, test_loader)\n",
    "    # 将模型切换至训练模式\n",
    "    model.train()\n",
    "    return re\n",
    "\n",
    "@torch.no_grad()\n",
    "def _loss(model, data_loader):\n",
    "    \"\"\"\n",
    "    计算模型在不同数据集下面的评估指标\n",
    "    \"\"\"\n",
    "    loss = []\n",
    "    data_iter= iter(data_loader)\n",
    "    # 随机使用多个批量数据来预估模型效果\n",
    "    for k in range(eval_iters):\n",
    "        data = next(data_iter, None)\n",
    "        if data is None:\n",
    "            data_iter = iter(data_loader)\n",
    "            data = next(data_iter, None)\n",
    "        inputs, labels = data['inputs'], data['labels']\n",
    "        logits = model(inputs)\n",
    "        loss.append(F.cross_entropy(logits, labels).item())\n",
    "    return torch.tensor(loss).mean().item()\n",
    "\n",
    "estimate_loss(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "id": "18nkKzWLsiST"
   },
   "outputs": [],
   "source": [
    "def train_mlp(model, optimizer, data_loader, epochs=10):\n",
    "    # 记录模型在训练集上的模型损失\n",
    "    lossi = []\n",
    "    for epoch in range(epochs):\n",
    "        for i, data in enumerate(data_loader, 0):\n",
    "            inputs, labels = data['inputs'], data['labels']\n",
    "            optimizer.zero_grad()\n",
    "            logits = model(inputs)\n",
    "            loss = F.cross_entropy(logits, labels)\n",
    "            lossi.append(loss.item())\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "        # 评估模型，并输出结果\n",
    "        stats = estimate_loss(model)\n",
    "        train_loss = f'train loss {stats[\"train\"]:.4f}'\n",
    "        test_loss = f'test loss {stats[\"test\"]:.4f}'\n",
    "        print(f'epoch {epoch:>2}: {train_loss}, {test_loss}')\n",
    "    return lossi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "n_af5vmnsiST",
    "outputId": "80a2167d-ff62-4261-b7c8-4e63f2b7287f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch  0: train loss 1.3726, test loss 1.5097\n",
      "epoch  1: train loss 1.2598, test loss 1.4965\n",
      "epoch  2: train loss 1.1934, test loss 1.4247\n",
      "epoch  3: train loss 1.1630, test loss 1.4014\n",
      "epoch  4: train loss 1.1505, test loss 1.3658\n",
      "epoch  5: train loss 1.1539, test loss 1.3594\n",
      "epoch  6: train loss 1.0862, test loss 1.3975\n",
      "epoch  7: train loss 1.0872, test loss 1.3718\n",
      "epoch  8: train loss 1.0707, test loss 1.3832\n",
      "epoch  9: train loss 1.0845, test loss 1.3286\n"
     ]
    }
   ],
   "source": [
    "l = train_mlp(model, optim.Adam(model.parameters(), lr=learning_rate), train_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 448
    },
    "id": "a4tA30XfiMFv",
    "outputId": "d8065926-15f2-4f0e-e3dd-2a5e3ef16768"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7de0931bf760>]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABL2UlEQVR4nO3dd3zTdf4H8FdGk86ki+5BoezSsqFspCw5BM+BiIKKKByoOE/0zvnTese5zlPQQ0VFxAl4yF5FoOxWWkahUGiBDtrSpDMd+f7+SPNt0qYL2n5p83o+HnnQfPNN+skXaF79jPdHJgiCACIiIiKJyKVuABEREdk3hhEiIiKSFMMIERERSYphhIiIiCTFMEJERESSYhghIiIiSTGMEBERkaQYRoiIiEhSSqkb0BRGoxFXr16Fm5sbZDKZ1M0hIiKiJhAEAYWFhQgICIBcXn//R7sII1evXkVwcLDUzSAiIqIbkJGRgaCgoHofbxdhxM3NDYDpzWg0GolbQ0RERE2h1+sRHBwsfo7Xp12EEfPQjEajYRghIiJqZxqbYsEJrERERCQphhEiIiKSFMMIERERSYphhIiIiCTFMEJERESSYhghIiIiSTGMEBERkaQYRoiIiEhSDCNEREQkKYYRIiIikhTDCBEREUmKYYSIiIgk1S42ymstn+9LQ0Z+Ce4bEoyeftyAj4iISAp23TPy24mrWHXgItLzSqRuChERkd2y6zAir97S2ChI3BAiIiI71qwwsnz5ckRGRkKj0UCj0SA6OhqbN2+u9/xVq1ZBJpNZ3RwdHW+60S3FHEYEgWmEiIhIKs2aMxIUFIR33nkH3bp1gyAI+OqrrzB9+nQkJCSgT58+Np+j0WiQkpIi3pdVB4Bbgbkp7BkhIiKSTrPCyLRp06zuv/XWW1i+fDkOHjxYbxiRyWTw8/O78Ra2InPPSBV7RoiIiCRzw3NGqqqqsHbtWhQXFyM6Orre84qKihAaGorg4GBMnz4dJ0+ebPS1DQYD9Hq91a01KOQcpiEiIpJas8NIUlISXF1doVarsWDBAqxbtw69e/e2eW6PHj3wxRdfYMOGDVi9ejWMRiOGDx+Oy5cvN/g9YmNjodVqxVtwcHBzm9kkNcM0DCNERERSkQnN7BYoLy9Heno6dDodfvrpJ6xcuRJxcXH1BhJLFRUV6NWrF2bNmoU333yz3vMMBgMMBoN4X6/XIzg4GDqdDhpNy9UDmfvFYcSdvYZ374nCXQODWux1iYiIyPT5rdVqG/38bnbRM5VKhfDwcADAwIEDceTIEXz44Yf49NNPG32ug4MD+vfvj9TU1AbPU6vVUKvVzW1as8nZM0JERCS5m64zYjQarXoxGlJVVYWkpCT4+/vf7LdtETV1RhhGiIiIpNKsnpGlS5diypQpCAkJQWFhIdasWYM9e/Zg69atAIA5c+YgMDAQsbGxAIA33ngDw4YNQ3h4OAoKCrBs2TJcunQJjz76aMu/kxsgY9EzIiIiyTUrjOTk5GDOnDnIzMyEVqtFZGQktm7digkTJgAA0tPTIZfXdLZcv34d8+fPR1ZWFjw8PDBw4EAcOHCgSfNL2oKiuqnsGSEiIpJOsyewSqGpE2Caa+HqY9icnIU3Z0TgwWGhLfa6RERE1PTPb+5NA9YZISIikpJdhxGxzggnjRAREUnGrsNITTl4iRtCRERkx+w6jLAcPBERkfTsOoywHDwREZH07DqMyFlnhIiISHJ2HkZMf7JnhIiISDp2Hkaqe0bYNUJERCQZuw4jLAdPREQkPbsOIywHT0REJD27DiOcwEpERCQ9hhGwzggREZGU7DqMsM4IERGR9Ow6jIjl4I0SN4SIiMiO2XUYYTl4IiIi6dl1GOEwDRERkfTsOoxwNQ0REZH07DyMmP5kzwgREZF07DyMmOeMSNwQIiIiO2bXYUQmrqZhGiEiIpKKXYcRhThnhGGEiIhIKnYdRmrmjEjbDiIiIntm32GE
dUaIiIgkZ9dhhHVGiIiIpGfXYYTl4ImIiKRn52HE9CeHaYiIiKRj52GEq2mIiIikxjACrqYhIiKSkp2HEdOf7BkhIiKSjn2HETnLwRMREUnNrsMIy8ETERFJz67DCMvBExERSc+uwwjLwRMREUnPzsMIy8ETERFJza7DCMvBExERSc+uw4hYDp5ZhIiISDL2HUaq3z2HaYiIiKRj32GEq2mIiIgkxzACwMhde4mIiCTDMAL2jBAREUnJzsOI6U9mESIiIunYdRgRy8EzjRAREUnGrsOIQs5hGiIiIqnZdRhhOXgiIiLp2XkYYTl4IiIiqdl1GGE5eCIiIunZdRhhnREiIiLpMYyAPSNERERSsu8wUv3uGUaIiIikY99hROwZkbghREREdoxhBOwZISIikpKdhxHTn8wiRERE0rHrMCKWg+c4DRERkWTsOozIWWeEiIhIcnYdRsx70zCLEBERSceuwwgnsBIREUnPrsMIy8ETERFJr1lhZPny5YiMjIRGo4FGo0F0dDQ2b97c4HN+/PFH9OzZE46Ojujbty82bdp0Uw1uSawzQkREJL1mhZGgoCC88847OHbsGI4ePYrbbrsN06dPx8mTJ22ef+DAAcyaNQvz5s1DQkICZsyYgRkzZiA5OblFGn+zavamYRohIiKSikwQbm6MwtPTE8uWLcO8efPqPDZz5kwUFxdj48aN4rFhw4ahX79+WLFiRZO/h16vh1arhU6ng0ajuZnmWknNKUTMe3vh4eyAhFcmttjrEhERUdM/v294zkhVVRXWrl2L4uJiREdH2zwnPj4eMTExVscmTZqE+Pj4Bl/bYDBAr9db3VqDjMM0REREkmt2GElKSoKrqyvUajUWLFiAdevWoXfv3jbPzcrKgq+vr9UxX19fZGVlNfg9YmNjodVqxVtwcHBzm9kkXE1DREQkvWaHkR49eiAxMRGHDh3CwoULMXfuXJw6dapFG7V06VLodDrxlpGR0aKvb8Zy8ERERNJTNvcJKpUK4eHhAICBAwfiyJEj+PDDD/Hpp5/WOdfPzw/Z2dlWx7Kzs+Hn59fg91Cr1VCr1c1tWrPJWQ6eiIhIcjddZ8RoNMJgMNh8LDo6Gjt37rQ6tn379nrnmLQ11hkhIiKSXrN6RpYuXYopU6YgJCQEhYWFWLNmDfbs2YOtW7cCAObMmYPAwEDExsYCAJ566imMGTMG7777LqZOnYq1a9fi6NGj+Oyzz1r+ndwAloMnIiKSXrPCSE5ODubMmYPMzExotVpERkZi69atmDBhAgAgPT0dcnlNZ8vw4cOxZs0a/O1vf8NLL72Ebt26Yf369YiIiGjZd3GDOIGViIhIejddZ6QttFadkZzCMgx5ayfkMuBC7NQWe10iIiJqgzojHQHLwRMREUmPYaQaS8ITERFJw67DiHkCKwBUMowQERFJwq7DiIOiJoyw1ggREZE07DqMWPeMGCVsCRERkf2y6zCitFiGXFnFnhEiIiIp2HUYUchlYhVWzhkhIiKShl2HEQBQyrk/DRERkZTsPoyY541wzggREZE07D6MmOeNcM4IERGRNBhGFOaeEYYRIiIiKTCMcM4IERGRpOw+jJjnjFRUcc4IERGRFOw+jJjnjLBnhIiISBoMI5wzQkREJCm7DyMKzhkhIiKSlN2HEfME1krOGSEiIpKE3YcRhbnOCHtGiIiIJGH3YcRBwWEaIiIiKdl9GKkpB88wQkREJAW7DyOcM0JERCQtuw8j7BkhIiKSlt2HEQcFi54RERFJye7DCHtGiIiIpGX3YYRzRoiIiKTFMMI6I0RERJKy+zCiYJ0RIiIiSdl9GFFyzggREZGk7D6MKDhnhIiISFJ2H0YcOGeEiIhIUnYfRjhnhIiISFp2H0a4tJeIiEhadh9GWPSMiIhIWnYfRlgOnoiISFp2H0bYM0JERCQtuw8jnDNCREQkLbsPI+wZISIikpbdhxHOGSEiIpKW3YcR9owQERFJy+7DCOeMEBERScvu
wwh7RoiIiKRl92FEyTkjREREkmIYYc8IERGRpOw+jJiHaSo4Z4SIiEgSdh9G1ErTJSivZBghIiKSgt2HEUcHBQCgrKJK4pYQERHZJ7sPI+aeEQN7RoiIiCTBMKI09YwwjBAREUnD7sOIo4PpEnCYhoiISBp2H0bYM0JERCQtuw8j5p4RA3tGiIiIJGH3YURtXk3DnhEiIiJJMIxY1BkRBFZhJSIiamt2H0bMdUYAzhshIiKSgt2HEXPPCAAYKhhGiIiI2prdhxGlXIbq7WlgqOQkViIiorZm92FEJpNZlIRnzwgREVFba1YYiY2NxeDBg+Hm5gYfHx/MmDEDKSkpDT5n1apVkMlkVjdHR8ebanRLqykJz54RIiKittasMBIXF4dFixbh4MGD2L59OyoqKjBx4kQUFxc3+DyNRoPMzEzxdunSpZtqdEtj4TMiIiLpKJtz8pYtW6zur1q1Cj4+Pjh27BhGjx5d7/NkMhn8/PxurIVtgCXhiYiIpHNTc0Z0Oh0AwNPTs8HzioqKEBoaiuDgYEyfPh0nT55s8HyDwQC9Xm91a03sGSEiIpLODYcRo9GIJUuWYMSIEYiIiKj3vB49euCLL77Ahg0bsHr1ahiNRgwfPhyXL1+u9zmxsbHQarXiLTg4+Eab2SRiSXjOGSEiImpzNxxGFi1ahOTkZKxdu7bB86KjozFnzhz069cPY8aMwS+//IJOnTrh008/rfc5S5cuhU6nE28ZGRk32swmMfeMcDUNERFR22vWnBGzxYsXY+PGjdi7dy+CgoKa9VwHBwf0798fqamp9Z6jVquhVqtvpGk3RM2eESIiIsk0q2dEEAQsXrwY69atw65duxAWFtbsb1hVVYWkpCT4+/s3+7mthT0jRERE0mlWz8iiRYuwZs0abNiwAW5ubsjKygIAaLVaODk5AQDmzJmDwMBAxMbGAgDeeOMNDBs2DOHh4SgoKMCyZctw6dIlPProoy38Vm6c2DPC1TRERERtrllhZPny5QCAsWPHWh3/8ssv8dBDDwEA0tPTIZfXdLhcv34d8+fPR1ZWFjw8PDBw4EAcOHAAvXv3vrmWt6CaomfsGSEiImprzQojgiA0es6ePXus7r///vt4//33m9WotsZy8ERERNKx+71pAJaDJyIikhLDCGp6RjhMQ0RE1PYYRlDTM8Jy8ERERG2PYQQsB09ERCQlhhFwozwiIiIpMYyAPSNERERSYhgB64wQERFJiWEElnVGOExDRETU1hhGwJ4RIiIiKTGMgHvTEBERSYlhBCx6RkREJCWGEVgM07BnhIiIqM0xjMBiAit7RoiIiNocwwjYM0JERCQlhhGw6BkREZGUGEZQUw6+0iigsoqBhIiIqC0xjKCmZwRg7wgREVFbYxhBzZwRgGGEiIiorTGMAJDLZVApTJeilJNYiYiI2hTDSDXzvJHScoYRIiKitsQwUs1ZpQTAzfKIiIjaGsNINSeVaRIrh2mIiIjaFsNINXMV1hIO0xAREbUphpFqzuaeEYYRIiKiNsUwUs3JvD8Nh2mIiIjaFMNINQ7TEBERSYNhpJozJ7ASERFJgmGkGodpiIiIpMEwUs28tLekvFLilhAREdkXhpFq5jkjpeXcm4aIiKgtMYxU45wRIiIiaTCMVOOcESIiImkwjFRz5JwRIiIiSTCMVHM2zxmp4JwRIiKitsQwUs28mqaMRc+IiIjaFMNINfOckWIO0xAREbUphpFqGicHAICutELilhAREdkXhpFqni4qAMD14nKJW0JERGRfGEaqeTibekaKy6tQXslJrERERG2FYaSaxtEBcpnp64IS9o4QERG1FYaRanK5DNrqeSPXSzhvhIiIqK0wjFjwcK6eN8KeESIiojbDMGLBvXreCIdpiIiI2g7DiAVxRQ2HaYiIiNoMw4gF9+phmnwu7yUiImozDCMWPDhMQ0RE1OYYRiy4O3OYhoiIqK0xjFgwr6ZhzwgREVHbYRixYB6mYc8IERFR22EY
seDhwjojREREbY1hxELNMA17RoiIiNoKw4gFy9U0RqMgcWuIiIjsA8OIBfNqGqMA6MvYO0JERNQWGEYsqJRyuKgUADiJlYiIqK0wjNTi7aYGAGTpyiRuCRERkX1gGKmli7cLAOD8tSKJW0JERGQfmhVGYmNjMXjwYLi5ucHHxwczZsxASkpKo8/78ccf0bNnTzg6OqJv377YtGnTDTe4tYX7uAIAUnMYRoiIiNpCs8JIXFwcFi1ahIMHD2L79u2oqKjAxIkTUVxcXO9zDhw4gFmzZmHevHlISEjAjBkzMGPGDCQnJ99041tD106mMMKeESIiorYhEwThhtewXrt2DT4+PoiLi8Po0aNtnjNz5kwUFxdj48aN4rFhw4ahX79+WLFiRZO+j16vh1arhU6ng0ajudHmNsnRi/m4e0U8ArSOOLB0fKt+LyIioo6sqZ/fNzVnRKfTAQA8PT3rPSc+Ph4xMTFWxyZNmoT4+Pib+datxjxMc1VXhmJDpcStISIi6vhuOIwYjUYsWbIEI0aMQERERL3nZWVlwdfX1+qYr68vsrKy6n2OwWCAXq+3urUVd2cVvF1N9UY4VENERNT6bjiMLFq0CMnJyVi7dm1LtgeAaaKsVqsVb8HBwS3+PRpinjfCSaxERESt74bCyOLFi7Fx40bs3r0bQUFBDZ7r5+eH7Oxsq2PZ2dnw8/Or9zlLly6FTqcTbxkZGTfSzBvGFTVERERtp1lhRBAELF68GOvWrcOuXbsQFhbW6HOio6Oxc+dOq2Pbt29HdHR0vc9Rq9XQaDRWt7YU7OkMAMjSs/AZERFRa1M25+RFixZhzZo12LBhA9zc3MR5H1qtFk5OTgCAOXPmIDAwELGxsQCAp556CmPGjMG7776LqVOnYu3atTh69Cg+++yzFn4rLUfrZNowT1/KkvBEREStrVk9I8uXL4dOp8PYsWPh7+8v3r7//nvxnPT0dGRmZor3hw8fjjVr1uCzzz5DVFQUfvrpJ6xfv77BSa9SM4cRHcMIERFRq2tWz0hTSpLs2bOnzrF77rkH99xzT3O+laQYRoiIiNoO96axgWGEiIio7TCM2MAwQkRE1HYYRmzQVIeRsgojDJVVEreGiIioY2MYscFNrYRMZvqavSNERESti2HEBrlcBo0jl/cSERG1BYaRenDeCBERUdtgGKmHh4tps7y7lsdjc1JmI2cTERHRjWIYqcfsISHi12uPtO3eOERERPaEYaQe9wwKQkwvHwBAAYdqiIiIWg3DSD1kMhnmjewCACg2VErcGiIioo6LYaQBbo6mavlcUUNERNR6GEYaIO7eW8YwQkRE1FoYRhpg7hkpqzCivNIocWuIiIg6JoaRBriqazY1LmTvCBERUatgGGmAUiGHi0oBACgs4yRWIiKi1sAw0ggN540QERG1KoaRRpjnjbBnhIiIqHUwjDSCG+YRERG1LoaRRnCYhoiIqHUxjDTCw9m0YV5ecbnELSEiIuqYGEYa4e1qCiO5hQwjRERErYFhpBHermoAQG6RQeKWEBERdUwMI43wdjMP0zCMEBERtQaGkUZ4uVT3jHCYhoiIqFUwjDSCwzRERESti2GkETXDNOUoq6iSuDVEREQdD8NIIzydVVDKZQCAJ75LkLg1REREHQ/DSCOUCjmW3t4LALD9VDZSsgolbhEREVHHwjDSBPNGhmFSH18AwG9JmRK3hoiIqGNhGGmiASEeAIBLecUSt4SIiKhjYRhpohBPZwDApbwSiVtCRETUsTCMNFGIlymMZOQzjBAREbUkhpEmMveM5BWX471tKRAEQeIWERERdQwMI03k5ugAmWmFL/69KxXRsbuQpSuTtlFEREQdAMNIMyy7O0r8Oktfhs/3XZCwNURERB0Dw0gz3D0wCC4qhXi/pJwVWYmIiG4Ww0gzGS2miiSkF7BEPBER0U1iGGmmSqNR/PpUph6xm05L2BoiIqL2j2GkmYI9nK3ufxV/SaKWEBERdQwMI8304X39
0dPPTbzvplZK2BoiIqL2j2GkmfoGabFlyWgk/H0CAKDQUImS8kqJW0VERNR+MYzcIA8XFdwcTb0iVwtKJW4NERFR+8UwchOCquePZFxnGCEiIrpRDCM3IczbFEYOp+VL3BIiIqL2i2HkJtwRFQAA+PHoZVRWGRs5m4iIiGxhGLkJ43v5wkWlQG6RARdyi/HuthRsSc6SullERETtCtel3gQHhRxdfVxx4rIO728/i83VQeTiO1MlbhkREVH7wZ6RmxTeyRUAxCACAIIg1Hc6ERER1cIwcpO6+rjWOaYvY90RIiKipmIYuUndfd3qHMsvLpegJURERO0Tw8hNGtejE56f1ANfPzIEvho1ACCvyCBxq4iIiNoPhpGbpFTIsWhcOEZ37wR/rRMA4EpBKd7bfhapOYUSt46IiOjWx9U0LcjbVQUAeGptIgDgu8PpOPJyjIQtIiIiuvWxZ6QFebqorO5fK+RwDRERUWMYRlqQl6u6zrHySlZmJSIiagjDSAuK6eVT59jZbM4bISIiagjDSAsaGOqJv07uiQEh7lDIZQCAP320DyXlrDtCRERUn2aHkb1792LatGkICAiATCbD+vXrGzx/z549kMlkdW5ZWR1zD5eFY7vil7+MwBO3hYvHDl3Ix75zuej72lZsTsqUsHVERES3nmaHkeLiYkRFReHjjz9u1vNSUlKQmZkp3nx86g5pdCRLYrrj7oFBAIC4s9fw+DdHUVhWiYXfHhfP0ZVWYFNSJsoqqqRqJhERkeSavbR3ypQpmDJlSrO/kY+PD9zd3Zv9vPbstp4++OnYZRy8kIfi8prAIQgCZDIZFq85jt/P5WL+qDC8PLW3hC0lIiKSTpvNGenXrx/8/f0xYcIE7N+/v62+raT6BmoBAGeyrCexZuSXAgB+P5cLAPg6/lLbNoyIiOgW0upFz/z9/bFixQoMGjQIBoMBK1euxNixY3Ho0CEMGDDA5nMMBgMMhpoaHXq9vrWb2SqCPJxsHk/IuI73tqeI9w1c/ktERHas1cNIjx490KNHD/H+8OHDcf78ebz//vv45ptvbD4nNjYWr7/+ems3rdXJZDKr+65qJYoMlWKFViIiIpJoae+QIUOQmppa7+NLly6FTqcTbxkZGW3Yupb13MTuAIB/3hWJN6b3kbg1REREtx5J9qZJTEyEv79/vY+r1Wqo1XWrmbZHC8Z0xdTIAIR5u+BibrHUzSEiIrrlNDuMFBUVWfVqpKWlITExEZ6enggJCcHSpUtx5coVfP311wCADz74AGFhYejTpw/KysqwcuVK7Nq1C9u2bWu5d3ELUyrkCPN2AQB09nbBu/dEQePkgC/2pSH+Qp543sXcYhy5mI+v4y/hL2O7Ykrf+sMaERFRR9LsMHL06FGMGzdOvP/MM88AAObOnYtVq1YhMzMT6enp4uPl5eV49tlnceXKFTg7OyMyMhI7duyweg17cld17ZHVB61X0Iz91x7x64XfHsfFd6a2ZbOIiIgkIxMEQZC6EY3R6/XQarXQ6XTQaDRSN6dFPPrVEew4nVPv41FBWjw2uiumRrKHhIiI2qemfn5zbxqJOCgavvR/XNbh4931T/IlIiLqKBhGJOLpohK/VsplNs+5lFcMc8fVkYv5SMniDsBERNTxSLKahoCnYrohIb0AMwcHY1Q3b8SdvYbX/3fK6pzi8irkFpWj2FCJe1bEAwDOvTWl0V4VIiKi9oRhRCI+bo7Y9NQo8f7l66Xi1128XXApvwRVRgGD39ph9byE9AIMCfNss3YSERG1Nv6KfYuwLB2/89kxCPV0tnnetpNZbdUkIiKiNsEwcovo0skV78+MwjfzhkAmk2F6v0Cb532xPw2nM9vnXj1ERES2MIzcQu7sH4RR3ToBAJ64LRzRXbzExz68rx+GhnnCKJiGaoiIiDoKhpFblFwuw2CLuSHRXb3Q2ctUyTWvyFDf0+r18e5UrIg732LtIyIiaimcwHoLUylqlvx2clXDy9W0HDiv
uLxZr5NXZMCyrSkAgFlDQqB1cmi5RhIREd0k9ozcwszzRoZ18YRMJoOXq2nzwAPnc/HaryeRW2RATmEZjMaGi+hm6srEr69YrNohIiK6FbBn5BYW7OmMwy+Nh9bZ1JPhXd0zcja7CGezi7DqwEUAQJi3Cz6+fwB6B9gutXuloCaAXL5eUu95Zul5Jci4XoIR4d4t8C6IiIgaxp6RW5yPxhFqpQKAddVWS2m5xVi85ni9r5FpEUbMwaS80ghdaYXN80cv243ZKw/hePr1G202ERFRk7FnpB3xclHX+9iF3GLoyyrww5EMyGUy+GjU6O2vQZdOrrhqMUxjLq62YPUxHE7Lx/pFI/C39UlwUMjh6KCwmhx7IDUXA0I8Wu8NERERgWGkXTEP01hyUMgggwzlVUb8d+8FfLSrZnM9D2cHBHk4I+mKTjy283Q2Fozpil1nTDsGP/19otXjlsoqjC38DoiIiOriME074umiQqC7k9WxiEAtOnubqrX+cvyK1WPXSyrqBI2LeSV48ecT4v36gggA6MtsD+MQERG1JIaRdkSpkON/T4zE3ufHicdmDgqGv9YUUCwnqtoyvV8AAGBnda9IY64W1AzvVFYZ8cGOszh0Ia+5zSYiImoQh2naGU8XFTxdVPjn3ZFIyy3GPYOCsf+8dUDY/+JtiEu5hrd+O4Xi8irx+B1RAdiQeFW8P72f9f3aUrL1+PfOc4g/nwe5HNifmocPcA4X35naaDurjAIUclmj5xEREbFnpJ26d1Aw/jq5JxRyGZwdFOLxFQ8MRKC7E+4fGoKk1ybhsdFdAACRQVpEBGqtXuPZCT3goqp5rkph/c8hI78U720/i/gLedifWhN4Zn4aj2x9Gerzxv9OYcCb25GRX3JT75GIiOwDw0gH8GRMN4wM98a3jw7F5Ag/8bhcLsOzE7vjzRkR+HzuYPi4qRHqZZpfEqB1RLCnE3Y+OxYrHhiAtNjb8cLkHk36fofS8nH/fw9aHbt8vQSxm04jt8iAL/anQVdagZ+PX265N0lERB0Wh2k6gEB3J6x+dKjNx9RKBR4cFire/3nhcKz8PQ3Du3pBJpPBT+uIyVp/AEAnN9tLh1/5U2+8sfGU1bHz14px/loRunZyxZ6UHDz05REAwPdHM8RzHKp7Wg6cz0W4jyt83Bxv/E0SEVGHxTBiZ7xd1XhxSk+bj/X210AmAwSL6vJuaiUeHtEZmbpS/Pf3NKvz715+APNGhuFf286KxwpKalbgXCs0YNeZbDyy6igC3Z2w/8XbWvbNEBFRh8AwQqJuvm7Y9OQoKOQyuDs7YNmWFDw6qgtkMhlentq7Thi5XlJhFURqy9KViRNkG1vp0xSXr5fgxZ+TcO/gYNwRFXDTr0dERLcGzhkhK738Neju6wYfN0csuycKPfzcxMcWjesKAHh+Ug/cOyioznN7+2uwZv5Q/OOuvgCAHaezG1ytU1hWgYIS2zsQl5ZX4YWf/sDO09nisXXHr2Bfai6e/C4ByQ3URyEiovaFYYSa7OmY7tj4xEgsGNMVQ8K8xOODQj2w8YmR+PbRoRje1Ru9/E0b8VXW2k24rKJmmbHRKCA6dheGxe60Om728e5U/HD0MuZ9dVQ8lpZXLH597NJ18TW3n8pGSXnlTb03QRAQd/YarhUaGj+ZiIhaFIdpqMmUCrm4PLi3f83Ov/cPDbFaNly7SqxZXnE5HOQyXMwrQbCnE4oMpgBx/loRQjyd4ebogJzCMjjI5Th5tabnI/mKDu9tPyuWsAeAczmFqKgy4q3fTuObg5cwc1AwZg0NwebkTDwd0x2O1cudBUHAv7alwE/jiAejO9dpU2JGAR7/5ij8tU5IzChATC8frJw7+MYvEhERNRvDCN2Q7r6uGN7VCyqlHNP7BVo95uWqxtt39sVr/zuJ8sqa/W3yi8rx5sZTOHwxHw8N7ywen/rvfXBTK/HVvCF46IvDcHRQWBVMm/fVEWTrrXssVh9MR5VRwHeHTat3vj+a
gZ+OX0aVUUCJoQqPjAyDrrQCTg4KfLz7PAAgW2/AkphuUCrkOHYpH39ffxKnMvXiYwCw43QOxizbjckRflg6pVfLXTAiIqqXTBAEofHTpKXX66HVaqHT6aDRaBp/At0SBEHAiz8nict9v3x4MB6uXgLcmpxVCriolcgvLsezE7vjn1tSxMfeuzcKfx4QhG4vb0JFVcP/9L98eDDG9fBp0bYJggCZjJVpicg+NPXzm3NGqNXIZDL84+5IjAz3BgDsbuKeOLb4aRwxvKsX/DSN1yopKa/CtUIDqoyCVRABgL1nr6G0vKrRIAIAD395BMcu5d9wm2vTl1Vg9LLdWLI2ocVek4ioI2AYoVbn5aoCAHwdf6nec6ZYVI41s5x70t3PDWvmD8N/5wy6oTaYC7rtS83F5/suNPl5xy8V3ND3MxMEAb+fu4bdZ3KwPuEKMvJLsT7xKlqyQ7Kyyoidp29+Ei8RkVQYRqjVNWWo4/ExXcWvn5vYHctnD8BT47uJx3yrw4SPxnaVWLPBnT3w04JovHVnBGYNCal5/dFd4O2qRm5RuVVtFGeLvXlsaWxERRAEHL2YbzMI/Hg0AzM/PYgHPz+MeV8dwR8ZNZNy9WXW58edvYaVv1+A0WIFUl6RAVm6+vcAMnvl15OY99VRLNua0ui5bS01pwhpucWNn0hEdo0TWKnVzegfiKmR/njw80M4eCEf0V28cPhiPqosPng7eznjn3dHYveZHDwyMgzOKiXizl4THzeHEC8XlXjsj1cnYlNSJpb+kiQeG9vDB4M6e2JQZ08YKqvw3eF0AKZeln/P6odHVh1BWYURUyL88MF9/VBWYURukQEHzufhyvVSrIg7b9X23KKaOigLVx/DlYJSrH50KJ5Yk4Aefm4YEe6NuV8cRmSQFj8vHI5DF/Lh5qiEg0KO5386IT7XKABbkjPF+9n6MmidHACYejbmfnEYgKkHZ3q/QJy6qsddyw/AQSFD/NLxcFGb/qsajQJkMojzTt7ZfAZrDpne45f7L+LVaX2a/fdzIzYkXsHBC3l47Y4+UCttB7piQyVi3osDAJx7a4q4PQARUW0MI9QmHBRyfP3IUBy8kIdBnT3goJDjrd9OY9WBiwAArZMD7h0UjHsHBYvPsZwf4lv9tVIhx7anR6OySoDWyQGzhoTgTKYe6xOv4rmJ3TF7aM0+PGqlAiseGIjDafmY0NsXSoUce54bh8SM67itpy9USjnUSgW0Tg7o2skVx9Ov1wkjOdW7E2fry7A5OQsA8K+tKYg7ew1x1fNPAODEZR3uXn4Af1zWQSYDRnXrVOcaFJfX1FOZ+P5edPNxxbyRYejSyVU8/v2RDEzvF4i3N51GaUUVSiuAPq9uRW9/DSICNVifcBUDQz3w9z/1hqGyyqq9IZ7OzftLuQlPrU0EAIR5u+Cx0V1tnmPZI5KlK0OwpzNKy6vg1EhvFBHZH4YRajMqpRyju9d8SFdU1Sz7tbXCxDKMmHsRAKC7r5vVea9Pj8Cr0/pALq/7GpMj/Kx2MrbcGLC2fkHu8NM4IktfMzRyLqcIxy5dF0MJYD335Xj6dfHrPy6bhmEEwTRRtjHncorw4i9JcLH4cD54IQ9pucU4eCHP6txTmXpxGXL8hTzc/u/f67xeen4JNiReqbPUuinOZhcixNNZrM9SH11pBZ7/8Q/x/tGL1/HYaFPF3FOZegwIcRf/Li9aFKm7WlCKgxfy8MLPJ/CPuyKtQmdrWPn7BRQZKrEkpnurfh8iahkMIySZXv4NL9PWONX886xvKMDMVhBpLrlchi1LRmHrySz89WfT0E/SFR3uWn6g3raevKq/6e9r2WNiFIBx/9pzw6/11NpETI7wa/R6Hb2YjysFpfByUSNTV4rnfzqBx0d3wV8n96xzLQ2VVcgrKsfGE1fx9qYzVo+dztJDEAS8sfEkvjucgUdGhOGVab0BAGnXasJI3Nlr+GSPqRfnhZ9OIFtXhnsHB4s9Xi1BEAQs25oCV0el
uIpqcoQfevrdXDmAIkMlvjpwEbf39UeYt0tLNJWIamGdEZJMZZURn+69gJHh3ogKdrd5zpsbT+HE5QKsfnRoox+wLelKQSlG/mMXmvq/4717o1BQUoHR3b0R897eZn8/lVKOBWO64t87zwEAlHIZHBRylFaXyv987iCr0viWOns542JeiXh/05Oj0Dug/v8nlVVGhL+82eZjbmolZg8LtdrZ+fkf/8BPxy/Xey3u7B+IdQlXxPu/vzAOwZ7OWLj6mDi0Zcvgzh74ccHweh9vrlNX9XV6jD68r1+dnqLKKiOUzZi/8tqvJ7HqwEV0clPjyMsxLdJWInvR1M9v9oyQZJQKORaNC2/wnL//qXcbtcZaoLsT/vvgIHy+Lw3xtYZManNRKfCnyAColE37gBvVzRtBHs7i5FoAmNrXH0/eFo6oIC30ZRXoH+yBr+Iv4sv9FwEA0V296nk1wFll/d/4TJa+wTBiHu6xpdBQiU/3nsfUvv7w1ajhqFLgx2OXG3w/lkEEMJXvzyk0NBhEAODIxZr9hRobHmqKlOy67+tUph7T+wVCEARUVAn4z65z+HTvBbx2Rx8kpF/H85N6isu+AeCjnedQYRTwdEw3cbjJvA3BrbRvUaauFIvXJOCh4Z0xjTtYUwfAMEJUj5jevhgc5omo17eJx468HINObmoUGSrxdfxFFJZVYnIfvwaDyITevigoKceRi9fxzITueLJ6yfL/zYjAtI/2Ib+4HEtv7wmlQo7xvXzF51mW0ndWKTG1rz9+S8qs8/oB7o5WAeNMVmGD7+twWsOF3AQBmPaffQ2eAwBujkrcNSAI6xOvoLJKgEIug660Aqcy9TjRxF2VX1qXhO+PZOCHx4dhYKhno+dfLSjF6Uw9buvpU2eeUfKVumHk4IV87EnJwUO1Kv+aV2AVl1fh4/sHADDN/3l3u2nZ9/R+AehaPbHYctTqgZWH8NUjQ6y2K2iM0Shg68ksDAj1QEl5FT7bex4LxnRFqNeND/m8/uspHLt0HccuXRfDyJksPdydVPDTWg996csqUFpe1aJDYh1ZRn4Jdp3JwczBwS0SkqWUX1wOdyeHFhnGbm0MI0QN0Do5wM1RicLquiDe1QXcXNVK/GWs7V6dEeFe2J+aBw9nB6yZPwxh3i6oMgrYkpyF2/vWTJ5VyGX43xMjUV5ptLnC5Pa+/vj2ULo4T+Hde6PwyMgw3LX8AABgYKgHPJxVeO2O3thxuqa6bWpOUZ3XKimvxH/3pmFalD8ONRJGmuqbeUPRL9hdHNL5/kgGXv31JHaczkFekakXYUlMN2Try7An5Royq2umhHg6Iz3fNKxkXpa843ROo2Fk79lrmFO9BPq/cwZhUKgHnlybgMkRfpg9NBQnLhfUec4fGQV1goil305kYvE4PXr5a/D572ni8fHvxuGLhwbhtp6+sByd2peaiyvXSxHi1fSVS98dScfL65LRy18DJwc5jqcX4LvDGRjf0wePje6CoV3q9noJggBDpbHeD0PLycEAcORiPu5ZEY8evm7Y+vRoq8fmfnEYZzILsePZMfVuYtkcGfklWJ9wBXOiO0PrXDOx/D+7zqHSKFhNGq6sMmLryWyUVlThzv6BzQpxbeV0ph7nr5n+z4zu3gn3rzyIjPxSXNWVtuv9qZIu6zDtP/swvV8APryvv9TNaRQX/hM1op/FfJam7Cuz7O4o3D0wCKsfHYpe/ho4Opj2yrlrYFCd0KGQy+pd6joi3Bs/L4zGzwtN8yocHRQYGOohFnN76faeWDl3EII8nPHpgwPhX/0b8a4zOZi98qDVaqW/rUvG+zvOYuHq4zhy0TqMvHR7T/x7Vn+M7t4J3z46tM4HRkwvH+x8dgzGWKyE+mT2APG6ODoo4OigQGSQaefm05l65FQPafQJ0CL2z5H4/rFoqJVy3BEVgL0vjMPkPtYVd/WlFQ1e05NXdWIQAYDdKTn4aFcqfj+Xi5fXJePOT/aLwz7NNeXD3xG7+TQ2J1v3
Oj2y6igOpObWKTx34HyuzdfZdjILU//9O1Jq9Ux9fcC0+up0ph4nLtf0GO08k4P3d5yFLa9sOInI17fh5XVJuFpQWufxEotJzzM/jcc9K+IBACnZhUi+osP0/+zDD0cykJhRgIT0ApRWVGHX6ez6LkGzPP7NMby7/SyWrqupo3Mprxj/2nYWH+w4h5zCmuu1IfEqFq05jud+/AM7Gvn+OYVlWPn7BZuhsjVN+fB3LF6TgMVrEvDyumRk5Juu9/paw4/tTdxZ0y8oGxKv4ujFltvWorWwZ4SoEW/f2Rfzvz6KOdGdm3R+gLsT/nVPVIt8b1u9Bf83IwJPT+gGH7eabvdJffzQxdsFE943TZ7dn5qH1QcvIcjDGdtPZeGX6h+sKdl1h3DMdULuqO7u/+HxaGTry6BWyjGqWydxCOqT2QPQ59WtAGAzQPULdrfqRQJqap+EeDnj2N8nwLn6N/3bI/2x5WTNnJL6Ks3uO5eLIkMFrpdYhxVzj4pZQnpBned6u6qsitaplXIYKo14bHQX9PJ3w9Pf1yxR/jTOtEWAh7OD1fe6f+WhOq/74i9JcFIp6kyMfeybYwCApb+cwC9/GQHANLcj43rN5OJKo/UsYPNu0UajgM/3pWFAqAcig7T45qApwHx7KB3fHkqHTGbqKZszLBQ/H78s9iwBqNPT9dyPf+BMViH+uHzC6vjpeobvDqTmIsDdCZ1trBS6WlCKD3acxYIxXeHtpoabWikOCW5Kqvn725NSs5Q9S1eGtGvFeHl9slUv3YVrtivxVhkFfB1/Ea//75R4bHBnD3z58BC4qlv3I8pY6+/jf39cFb/Os/i305Ckyzr8nnoNE3v7ItzHrfEntBHLVXpvbzqNnxcOr/PLlCAIEISWWY14sxhGiBoR7OmMLUtGN35iG1HIZVZBxCygVhe85Q/32mJ6+eLugYHivAhLA0M9bD7HxeKDoZtP3efJZDJ8MnsAHvy8pgcj2LOmTZYfLDG9fKyGazItwsivf1zF1uQsvDkjAg98XjcM1MdNrcSSCd3x5kbT+x4R7o0NiTUfLodeGg9DpRGeLio4KOSQy2Q4lanH4bR8JKQXwEWlwI8LhkPjpISLSokXf0kSP5xGhHvh1FW9GFSeWptYbz0Xc8DYkpyFBauPNdjmtNxifHXgIqqMAt7adBoAbM4/EgTTkNJvJ+rOGaqtvjlDxy/V7Tk6dVUvBq6L70wVj5eWVyH+Qi4+2pWKhPQC/HD0MmQy4LHRXaye//HuVGw7lQ2lxYfZztM5uFpQWme4MEtXitjNp3ExtxivTuuD0ooqeLuo8fam0+LO3mZHLl7HuoQreHBYKG6EeYLvnOjQOn9PgiDg5FU9unZyRXED+zmZg2NqThH+tTUFMwcHY1xP660tig2VeHjVEeQWGfDPLSmYNSQEf5vay+r/SnO8ufEUTl3V4+t5Q8SKxVtPZkEQBEyOsF0fqT45+poJ18fTC/Dp3guYEx1qNeH94VVHkJFfgt+eHCX5/BiGEaIOoqEfgC9M7oHV8ZdwVWcqQ//2nRHwuYEJjbufG4v84nIEedieMzGqWyfEPT8WY5btQYinc52VPmbOKiW2PT0aZ7MLccd/9uNUph5LfzmBucM748nvTLsaN2V+wb2DgvDTsctY8cBATKwe+jGHkbKKKkQFafHHZR3cnR3g7qyyeu70foGY3i8Qhsoq7E/NRWcvF6tquB/N6o87+wfgXHYRHhrRGf+38bTYYwEAm5MyMaWvP9Jyi8UVN4BpWfjhtHw81cDuzL4atRhaXv31pNVj5onL0/sFYNeZHKuepubwdlVBLpNh5dxB+PMnB3AmqxBnsvRWdVcSMwrEryurjNidcg3fHLwEtVKO7aesh1UEoaYHyczWfkgfVi9Pr+14egGSqic2bz1pCjCd3NRWQdTS4bR8XC8ux4PDQuHhorJ5Tn3e23ZWnOBrDiNVRgF7z13Dleul+Nv6ZMT08sHTExouirch8YpYbXjLySxs
e3o09qfmYn9qHgRBwPlrRcgtqvnQ/+5wOr47nI4Xp/TEgjG2KxPXRxBMvWOAaW7SuB4+yCksw+PVPW7myfOW5//15xMQBOCfd0dCJpMht8iAHL0BvQM0VsNlgGnriMvXS+CnccR3hzOw6uHBYo/W8fTrGN7Vu1ntbWkMI0Qd3P1DQ/CXseEY3tUbm5Iy8fjoLvBybXjDwfqEebs0Wvgr1MsFu58bCxd1w79pOTooEOpZ81rfHc7Ad4drfkP+1aLL3JbnJnbH4tu6IfbPkTaDS99ALf51TxTe234W/UNs9/YApoJ6t/X0tfnYbT19xcfuGRRkFUYWf5eA+8/nWR0zu/fT+Abb/tOC4Rj1z931Pt7Tzw0vTO6Jh0eEYcbH++s8HurljEsWtWX+PCAQj4/uigWrjyEttxivTeuNWUNDIAim6zyxjy82JWVh3fErWHp7TRgxWhSPySk0YP7XtmvZtISkWiusKo1CvUEEqBky2ZSUiU1PjmrWUEKOxTLsjPwSrD50Cdf0BnG4EjBNmr5/aIitp4vMQcRszaF0cQuLhryz+YzNMGI0CrhSUIogDydUGQUs+T4RfhpH/O1Pva2GVQzV9YX2p9bMT0rMKMCE3jX/TjPyS/HDUdOye11pBf51bxQeWHkIZ7IK8eF9/cSeEbnMVEwRAFYfrBneNA/pAk0fkmpNDCNEHcikPr7YejIbr9/RB+WVRsweFiL2TvQLdreajNuamlqpVOOkhJtaiUJD47/99w3U4va+/vjHFlMV2LsHmkrK1w4iO54Zg22nsvDw8DA4qRQttnlgZJA74p4fC08XFR5ZdQRHLl63GUSawruRMGgeFgx0d8KjI8Owcl/NSp+oYHdsWDQCnV/8TTz23r39AACbnxqFbH1ZnWXDQ8O8sCkpy2r+yo5T2fjDomfkio2JsgDgoJChoqr1a2P+KdIfG20MQ53JKkRusaHO0GTs5tNQKeR4cnw3bD2ZhYgALTp7u+DCtSKrTTaf+SGx3snNKVmmoaTbevpApZBbzWOyZF4hZyuIRAZpERmktfqgB0yVi2sXavzXthR8suc8VjwwAB7OKvH9Pj6mq1UvlXmC8u9na8JIQvp1eLmq8Ld1yXhzRoRV3Zttp7Jxz/J4cU7Y39Yli+GtX7A7jtuYU2XpwrViZOnK4OWqkmxDS66mIepAPpjZHxufGIk50aGYP7pLvcMktwqZTIa+1atwzCKDtFZzEMx+XTwCC8d2xcf3D8AXDw2qU0/DLNzHFX8ZG94qG/KFernAzdEBD9SayzBriO29doaGeeLtO/vWOe6kUmB4dSG7h4Z3hqqBD4DaC7h6+VlPkrQMY44OCpv1S8zd++YPsPjzeXj066NWBe3mWMz1MXtmQnecfmMy/nl3JL58eLBVO3c+OwYPDe9cb7vNc48s915yUNTfu/Hm9Agsnz3A5mMPrDyE89eK8Pm+NBSWVSAlqxCfxl3AR7tSMeofu7F4TQKerB4Wq72Uu6FVVuZg66tR4917bU8693B2wDt/jqz3NTq5qvHC5J64a0CQ1XHLnivANKxi3hLh/e3nrFZWvbP5jFWv1DM//IHkKzqrSceJGQWY9dlBnMrU46EvD9cpXmg5Ob3QUAld9Qq1yCD3ettu9v6OsxgWu1Nc4iwFhhGiDsRJpUBEoLZJS5BvFbWLcQ0M9cCa+cOw7i/Dsff5cXBTK/HAsBDxPU2N9K93WKWtBFvskDy5jx9i/xyJAy/eBudaAeieQcG4f2gIPryvX53X+M/9A/Dzwmi8dkcf7HpuDOaPCoNchjrnDu5svaLKvOrJ/GE/f5T1pFJbLMOIIAg2l3qatx4we3RkGJ4c3w1KhRz3DgrGuB4+6O5XM6emaydXvHZHH6vgODe6JqQ9Nb4bDr00HgeWjkdML18M6+KJwy/FYHT3Tpg3MgwqpRyODnLMHxWGj2b1h4eLClP6+uP43yfATa1EiKczelYHr7PZRRj/bhze3HgKfV/bhkkf1AwxmDe2
PHFZh4KScquVRpYm9/Grt6JzqJcLXNRKeNqYmxLs6YxgT2f8KdL2BFJPFxU0jg54994obHpylHj8Qq0Pdsvw4KNRW22y+fPxulWO//TRPlyyqCdzNrsIhur5RIVllUiqZwm0ZY4f0tkTGotNRhtzvbjhJfat6db+tYmIOrzR3b2tSsr7uDliSFjNB3DCKxOatZdMWwi1CCOBHqYVQwHuTtjxzBgUGSqxNTkLh9LyMamPKTRN7xeIkeHemPbRPnFVhKeLCp4upvcZ5OGMl6f2xrMTe9RZ1TChty/euzcKrmolnFVKDA83TTT8ZPYAxKVcw4z+je/S3Kl6WOhiXgn6vbEdTk1YORHkUbdA2vOTemLuF4cxxWInbK2TA/KKTXMOXp8egTsHBCFbX2a1Q/fKuYPEr79+ZAgAYObgYLiqlXVWgXm6qLDzuTFQKxR49sfERisKW4qO3VXvYx/PHlBnN2zA1Hs1t3rZvtbJAfnF1vMnzMHztTv6QOvkgP4hHgh0d8Ks/x40tde1JsD0DtDgz/0D8UvCFZzNLsLkCNPx1JxCzFtV0/ORV1SOc9mN90JY1pOxnCgLALurJ59ufGIkPthxDjtOZyNA64hp/QLEicYf3NfPqrdDIZfhnT/3hUopt5oPMyLcC189PETS/2cMI0QkqRn9AmGoMOLF6hLtXq7Wv53eakEEgNVv0JZj7OYP1u6+bnii1nO8XNXY/+JtDfZa2VpeKZPJ8OdaQwCAqUfp3sG2h4dqs1yFoSutELvwzf7cPxCTI/yw5WQWfjluCoaWvT9mY7p3wo5nxsBXU/N6c6I74/0dZzG4s6mnpqnzkrr71l+Twzw/xMO58VU0Yd4uKDZUIqfQIPbuBHk44eP7ByAluxAv/HQC43p0gkIuQ6hF5VxHBzkqqgT8465IcUhvRr/AOoXozMHT21WNt6qH3M5aDIl41epN6R2gwS8JV/B1/EX8dOwyXp/eB6//elLswQFqekmUchmGh3tjr8Ucl7rXQm01Ibe2PgEafDJ7ANYeSUdUkDv8tY64pjdg9rBQBLg7wV/riOWzB6BLJ1d4uqjEfwtje/jg4IU8jAz3hpODQvJaIwwjRCQpmUyG+4aEoKC0AvHn88RhiFuZTCZDgNYRV3VlmNin6UNGUg2f1bfs+43pfRDu4you6xwS5imGkfr2sgmvVWNm4diu6NLJRZwD05IMFvsz2TIo1APfzh+Kz/el4Z9bTMuMB4S448P7+iPY0xlRwe4Y3NlT/AAO8nDGy7f3gpNKgWFdPFFSXmVV7G3h2K7w1agR3dUL3x/JwNWCUpsrbiwDiLbWMIh5joap4F45Hm5gO4LO3i747MGBuJRXYjX0ZMnDWYXCsso6w2gA8NfJPSGTyaBSyqyKMr43s5/4tUwmw5S+dYeYtE4OmFSrErKUGEaI6JawYEzXZtdmkNKGxSORcb0EAxpYNnyrG9Wtk9XKJ62TA8J9XKEvragTOuqjUspbbedgN8eaj6g1jw7FkYvXrXouunZyhVqpwPxRXTC6Wyf09tfU+Q2/9squ+aPrn2OjUspxX/V2Cy9M7lnveZY1a2oHzD4BGqvltLXfjwyAvrp2zJjuneDooEAPPzcsHheO/+xOBWAKfG/f2RevbEjGE7d1g6ODHB/sOId/3h2JyioBpzP1uKNfgOSFyloSwwgR0Q3o5Ka2Gv641ZmXzs6NDsX2U9noH+qBzrU2/JPJZNj4xEgYBeGW+KB74rZuOHlVj/uHhmB4uDcGhHrgj8sFYpE58/5LDgo5IgK1Db1Ui7JcwVS7B8lFrURPP02d1S4OChk2PTkKBSUVOHnVtJJmssXcm2cndsec4aEwVBjh4aKCq1ppVfnZckfv2ivQOgKZIAitv4D8Jun1emi1Wuh0Omg0msafQEREViqqjMgtMsBfe/M790rNXGNl8bhwPDephyRt+O1EJk5n6vHsxO51ekeW/pKE7w6b6o48O6E7Tmfp8eCwzohuhaGsW11TP7/Z
M0JEZAccFPIOEUQA4IuHBuHXxKt4fEzjy5pby9RIf0ytZ7nvpD6+Yhh5Yny3tmxWu8WeESIioha2KSkTIZ7ObTp8dCtizwgREZFEbrexgoXqd+st4CciIiK7wjBCREREkmp2GNm7dy+mTZuGgIAAyGQyrF+/vtHn7NmzBwMGDIBarUZ4eDhWrVp1A00lIiKijqjZYaS4uBhRUVH4+OOPm3R+Wloapk6dinHjxiExMRFLlizBo48+iq1btza7sURERNTxNHsC65QpUzBlypQmn79ixQqEhYXh3XffBQD06tUL+/btw/vvv49JkyY199sTERFRB9Pqc0bi4+MRExNjdWzSpEmIj4+v9zkGgwF6vd7qRkRERB1Tq4eRrKws+PpabyTl6+sLvV6P0tJSm8+JjY2FVqsVb8HBTduZkoiIiNqfW3I1zdKlS6HT6cRbRkaG1E0iIiKiVtLqRc/8/PyQnZ1tdSw7OxsajQZOTrZLE6vVaqjV7WcDKiIiIrpxrd4zEh0djZ07d1od2759O6Kjo1v7WxMREVE70OwwUlRUhMTERCQmJgIwLd1NTExEerppU6ClS5dizpw54vkLFizAhQsX8MILL+DMmTP45JNP8MMPP+Dpp59umXdARERE7Vqzw8jRo0fRv39/9O/fHwDwzDPPoH///njllVcAAJmZmWIwAYCwsDD89ttv2L59O6KiovDuu+9i5cqVXNZLREREALhrLxEREbWSDrVrrzkvsd4IERFR+2H+3G6s36NdhJHCwkIAYL0RIiKidqiwsBBarbbex9vFMI3RaMTVq1fh5uYGmUzWYq+r1+sRHByMjIwMDv9Y4HWxjdfFNl4X23hdbON1sa2jXhdBEFBYWIiAgADI5fVPU20XPSNyuRxBQUGt9voajaZD/eW3FF4X23hdbON1sY3XxTZeF9s64nVpqEfE7JaswEpERET2g2GEiIiIJGXXYUStVuPVV19l6flaeF1s43WxjdfFNl4X23hdbLP369IuJrASERFRx2XXPSNEREQkPYYRIiIikhTDCBEREUmKYYSIiIgkZddh5OOPP0bnzp3h6OiIoUOH4vDhw1I3qVXt3bsX06ZNQ0BAAGQyGdavX2/1uCAIeOWVV+Dv7w8nJyfExMTg3LlzVufk5+dj9uzZ0Gg0cHd3x7x581BUVNSG76JlxcbGYvDgwXBzc4OPjw9mzJiBlJQUq3PKysqwaNEieHl5wdXVFXfddReys7OtzklPT8fUqVPh7OwMHx8fPP/886isrGzLt9Kili9fjsjISLEAU3R0NDZv3iw+bo/XpLZ33nkHMpkMS5YsEY/Z63V57bXXIJPJrG49e/YUH7fX6wIAV65cwQMPPAAvLy84OTmhb9++OHr0qPi4Pf7ctUmwU2vXrhVUKpXwxRdfCCdPnhTmz58vuLu7C9nZ2VI3rdVs2rRJePnll4VffvlFACCsW7fO6vF33nlH0Gq1wvr164U//vhDuOOOO4SwsDChtLRUPGfy5MlCVFSUcPDgQeH3338XwsPDhVmzZrXxO2k5kyZNEr788kshOTlZSExMFG6//XYhJCREKCoqEs9ZsGCBEBwcLOzcuVM4evSoMGzYMGH48OHi45WVlUJERIQQExMjJCQkCJs2bRK8vb2FpUuXSvGWWsSvv/4q/Pbbb8LZs2eFlJQU4aWXXhIcHByE5ORkQRDs85pYOnz4sNC5c2chMjJSeOqpp8Tj9npdXn31VaFPnz5CZmameLt27Zr4uL1el/z8fCE0NFR46KGHhEOHDgkXLlwQtm7dKqSmporn2OPPXVvsNowMGTJEWLRokXi/qqpKCAgIEGJjYyVsVdupHUaMRqPg5+cnLFu2TDxWUFAgqNVq4bvvvhMEQRBOnTolABCOHDkinrN582ZBJpMJV65cabO2t6acnBwBgBAXFycIgukaODg4CD/++KN4zunTpwUAQnx8vCAIppAnl8uFrKws8Zzly5cLGo1GMBgMbfsGWpGHh4ew
cuVKu78mhYWFQrdu3YTt27cLY8aMEcOIPV+XV199VYiKirL5mD1fl7/+9a/CyJEj632cP3dr2OUwTXl5OY4dO4aYmBjxmFwuR0xMDOLj4yVsmXTS0tKQlZVldU20Wi2GDh0qXpP4+Hi4u7tj0KBB4jkxMTGQy+U4dOhQm7e5Neh0OgCAp6cnAODYsWOoqKiwui49e/ZESEiI1XXp27cvfH19xXMmTZoEvV6PkydPtmHrW0dVVRXWrl2L4uJiREdH2/01WbRoEaZOnWr1/gH+Wzl37hwCAgLQpUsXzJ49G+np6QDs+7r8+uuvGDRoEO655x74+Pigf//++O9//ys+zp+7NewyjOTm5qKqqsrqHz4A+Pr6IisrS6JWScv8vhu6JllZWfDx8bF6XKlUwtPTs0NcN6PRiCVLlmDEiBGIiIgAYHrPKpUK7u7uVufWvi62rpv5sfYqKSkJrq6uUKvVWLBgAdatW4fevXvb9TVZu3Ytjh8/jtjY2DqP2fN1GTp0KFatWoUtW7Zg+fLlSEtLw6hRo1BYWGjX1+XChQtYvnw5unXrhq1bt2LhwoV48skn8dVXXwHgz11L7WLXXqK2sGjRIiQnJ2Pfvn1SN+WW0KNHDyQmJkKn0+Gnn37C3LlzERcXJ3WzJJORkYGnnnoK27dvh6Ojo9TNuaVMmTJF/DoyMhJDhw5FaGgofvjhBzg5OUnYMmkZjUYMGjQIb7/9NgCgf//+SE5OxooVKzB37lyJW3drscueEW9vbygUijqzubOzs+Hn5ydRq6Rlft8NXRM/Pz/k5ORYPV5ZWYn8/Px2f90WL16MjRs3Yvfu3QgKChKP+/n5oby8HAUFBVbn174utq6b+bH2SqVSITw8HAMHDkRsbCyioqLw4Ycf2u01OXbsGHJycjBgwAAolUoolUrExcXh3//+N5RKJXx9fe3yutji7u6O7t27IzU11W7/vQCAv78/evfubXWsV69e4hCWvf/ctWSXYUSlUmHgwIHYuXOneMxoNGLnzp2Ijo6WsGXSCQsLg5+fn9U10ev1OHTokHhNoqOjUVBQgGPHjonn7Nq1C0ajEUOHDm3zNrcEQRCwePFirFu3Drt27UJYWJjV4wMHDoSDg4PVdUlJSUF6errVdUlKSrL6gbF9+3ZoNJo6P4jaM6PRCIPBYLfXZPz48UhKSkJiYqJ4GzRoEGbPni1+bY/XxZaioiKcP38e/v7+dvvvBQBGjBhRp1TA2bNnERoaCsB+f+7aJPUMWqmsXbtWUKvVwqpVq4RTp04Jjz32mODu7m41m7ujKSwsFBISEoSEhAQBgPDee+8JCQkJwqVLlwRBMC0xc3d3FzZs2CCcOHFCmD59us0lZv379xcOHTok7Nu3T+jWrVu7XmK2cOFCQavVCnv27LFallhSUiKes2DBAiEkJETYtWuXcPToUSE6OlqIjo4WHzcvS5w4caKQmJgobNmyRejUqVO7Xpb44osvCnFxcUJaWppw4sQJ4cUXXxRkMpmwbds2QRDs85rYYrmaRhDs97o8++yzwp49e4S0tDRh//79QkxMjODt7S3k5OQIgmC/1+Xw4cOCUqkU3nrrLeHcuXPCt99+Kzg7OwurV68Wz7HHn7u22G0YEQRB+Oijj4SQkBBBpVIJQ4YMEQ4ePCh1k1rV7t27BQB1bnPnzhUEwbTM7O9//7vg6+srqNVqYfz48UJKSorVa+Tl5QmzZs0SXF1dBY1GIzz88MNCYWGhBO+mZdi6HgCEL7/8UjyntLRU+Mtf/iJ4eHgIzs7Owp133ilkZmZavc7FixeFKVOmCE5OToK3t7fw7LPPChUVFW38blrOI488IoSGhgoqlUro1KmTMH78eDGICIJ9XhNbaocRe70uM2fOFPz9/QWVSiUEBgYKM2fOtKqlYa/XRRAE4X//+58QEREhqNVqoWfPnsJnn31m9bg9/ty1RSYIgiBNnwwRERGRnc4ZISIiolsHwwgRERFJimGEiIiIJMUwQkRE
RJJiGCEiIiJJMYwQERGRpBhGiIiISFIMI0RERCQphhEiIiKSFMMIERERSYphhIiIiCTFMEJERESS+n+vc4gDPqKMswAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Smooth the recorded values in `l` by averaging non-overlapping windows.\n",
    "# NOTE(review): `l` is built in an earlier cell; presumably per-step training losses - confirm.\n",
    "smooth_window = 10\n",
    "values = torch.tensor(l)\n",
    "# Drop a trailing partial window so view(-1, smooth_window) cannot fail\n",
    "# when len(l) is not a multiple of smooth_window.\n",
    "values = values[: len(values) // smooth_window * smooth_window]\n",
    "plt.plot(values.view(-1, smooth_window).mean(1).numpy())\n",
    "plt.xlabel(f'window index (size {smooth_window})')\n",
    "plt.ylabel('window mean');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "JQ23lyBOsiST",
    "outputId": "0e3c46a9-e58a-489f-cfb2-46e41d7c688f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def gandas(self).s))).Streamed for i += num))\n",
      "    \"\"\"\n",
      "        batchraces.\n",
      "       ...\n",
      "        \"\"\"\n",
      "        if self._jvm.SSL0 0.1:\n",
      "                                                                      name = thod:\n",
      "        \"\"\"\n",
      "        ...     (argitparam) for Jrbteast = df.tors.defaultBy recrient short \n"
     ]
    }
   ],
   "source": [
    "# Generate text with the model, seeding it with a (1, context_length) context\n",
    "# of token id 0; use the context_length hyperparameter instead of a hard-coded width.\n",
    "context = torch.zeros((1, context_length), dtype=torch.long, device=device)\n",
    "print(''.join(tok.decode(generate(model, context))))"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "V100",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
